diff --git a/ecc/bls12-377/fr/gkr/gkr.go b/ecc/bls12-377/fr/gkr/gkr.go deleted file mode 100644 index 3305e0dae..000000000 --- a/ecc/bls12-377/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bls12-377/fr/gkr/gkr_test.go b/ecc/bls12-377/fr/gkr/gkr_test.go deleted file mode 100644 index 617c9bc39..000000000 --- a/ecc/bls12-377/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bls12-377/fr/gkr/registry.go b/ecc/bls12-377/fr/gkr/registry.go deleted file mode 100644 index 64f13cc0a..000000000 --- a/ecc/bls12-377/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bls12-377/fr/poseidon2/gkrgates/gkrgates.go b/ecc/bls12-377/fr/poseidon2/gkrgates/gkrgates.go deleted file mode 100644 index 83717b122..000000000 --- a/ecc/bls12-377/fr/poseidon2/gkrgates/gkrgates.go +++ /dev/null @@ -1,248 +0,0 @@ -// Package gkrgates implements the Poseidon2 permutation gate for GKR -// -// This implementation is based on the [poseidon2] package, but exposes the -// primitives as gates for inclusion in GKR circuits. - -// TODO(@Tabaie @ThomasPiellard) generify once Poseidon2 parameters are known for all curves -package gkrgates - -import ( - "fmt" - "sync" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" -) - -// The GKR gates needed for proving Poseidon2 permutations - -// extKeySBoxGate applies the external matrix mul, then adds the round key, then applies the sBox -// because of its symmetry, we don't need to define distinct x1 and x2 versions of it -func extKeySBoxGate(roundKey *fr.Element) gkr.GateFunction { - return func(x ...fr.Element) fr.Element { - x[0]. - Double(&x[0]). - Add(&x[0], &x[1]). - Add(&x[0], roundKey) - return sBox2(x[0]) - } -} - -// intKeySBoxGate2 applies the second row of internal matrix mul, then adds the round key, then applies the sBox, returning the second element -func intKeySBoxGate2(roundKey *fr.Element) gkr.GateFunction { - return func(x ...fr.Element) fr.Element { - x[0].Add(&x[0], &x[1]) - x[1]. - Double(&x[1]). - Add(&x[1], &x[0]). - Add(&x[1], roundKey) - - return sBox2(x[1]) - } -} - -// extAddGate (x,y,z) -> Ext . (x,y) + z -func extAddGate(x ...fr.Element) fr.Element { - x[0]. - Double(&x[0]). - Add(&x[0], &x[1]). - Add(&x[0], &x[2]) - return x[0] -} - -// sBox2 is Permutation.sBox for t=2 -func sBox2(x fr.Element) fr.Element { - var y fr.Element - y.Square(&x).Square(&y).Square(&y).Square(&y).Mul(&x, &y) - return y -} - -// extKeyGate applies the external matrix mul, then adds the round key, then applies the sBox -// because of its symmetry, we don't need to define distinct x1 and x2 versions of it -func extKeyGate(roundKey *fr.Element) func(...fr.Element) fr.Element { - return func(x ...fr.Element) fr.Element { - x[0]. - Double(&x[0]). - Add(&x[0], &x[1]). - Add(&x[0], roundKey) - return x[0] - } -} - -// for x1, the partial round gates are identical to full round gates -// for x2, the partial round gates are just a linear combination - -// extGate2 applies the external matrix mul, outputting the second element of the result -func extGate2(x ...fr.Element) fr.Element { - x[1]. - Double(&x[1]). - Add(&x[1], &x[0]) - return x[1] -} - -// intGate2 applies the internal matrix mul, returning the second element -func intGate2(x ...fr.Element) fr.Element { - x[0].Add(&x[0], &x[1]) - x[1]. - Double(&x[1]). - Add(&x[1], &x[0]) - return x[1] -} - -// intKeyGate2 applies the second row of internal matrix mul, then adds the round key -func intKeyGate2(roundKey *fr.Element) gkr.GateFunction { - return func(x ...fr.Element) fr.Element { - x[0].Add(&x[0], &x[1]) - x[1]. - Double(&x[1]). - Add(&x[1], &x[0]). - Add(&x[1], roundKey) - - return x[1] - } -} - -// powGate4 x -> x⁴ -func pow4Gate(x ...fr.Element) fr.Element { - x[0].Square(&x[0]).Square(&x[0]) - return x[0] -} - -// pow4TimesGate x,y -> x⁴ * y -func pow4TimesGate(x ...fr.Element) fr.Element { - x[0].Square(&x[0]).Square(&x[0]).Mul(&x[0], &x[1]) - return x[0] -} - -// pow2Gate x -> x² -func pow2Gate(x ...fr.Element) fr.Element { - x[0].Square(&x[0]) - return x[0] -} - -// pow2TimesGate x,y -> x² * y -func pow2TimesGate(x ...fr.Element) fr.Element { - x[0].Square(&x[0]).Mul(&x[0], &x[1]) - return x[0] -} - -const ( - Pow2GateName gkr.GateName = "pow2" - Pow4GateName gkr.GateName = "pow4" - Pow2TimesGateName gkr.GateName = "pow2Times" - Pow4TimesGateName gkr.GateName = "pow4Times" -) - -type roundGateNamer string - -// RoundGateNamer returns an object that returns standardized names for gates in the GKR circuit -func RoundGateNamer(p *poseidon2.Parameters) roundGateNamer { - return roundGateNamer(p.String()) -} - -// Linear is the name of a gate where a polynomial of total degree 1 is applied to the input -func (n roundGateNamer) Linear(varIndex, round int) gkr.GateName { - return gkr.GateName(fmt.Sprintf("x%d-l-op-round=%d;%s", varIndex, round, n)) -} - -// Integrated is the name of a gate where a polynomial of total degree 1 is applied to the input, followed by an S-box -func (n roundGateNamer) Integrated(varIndex, round int) gkr.GateName { - return gkr.GateName(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n)) -} - -var initOnce sync.Once - -// RegisterGkrGates registers the Poseidon2 compression gates for GKR -func RegisterGkrGates() error { - const ( - x = iota - y - ) - var err error - initOnce.Do( - func() { - p := poseidon2.GetDefaultParameters() - halfRf := p.NbFullRounds / 2 - gateNames := RoundGateNamer(p) - - if err = gkr.RegisterGate(Pow2GateName, pow2Gate, 1, gkr.WithUnverifiedDegree(2), gkr.WithNoSolvableVar()); err != nil { - return - } - if err = gkr.RegisterGate(Pow4GateName, pow4Gate, 1, gkr.WithUnverifiedDegree(4), gkr.WithNoSolvableVar()); err != nil { - return - } - if err = gkr.RegisterGate(Pow2TimesGateName, pow2TimesGate, 2, gkr.WithUnverifiedDegree(3), gkr.WithNoSolvableVar()); err != nil { - return - } - if err = gkr.RegisterGate(Pow4TimesGateName, pow4TimesGate, 2, gkr.WithUnverifiedDegree(5), gkr.WithNoSolvableVar()); err != nil { - return - } - - extKeySBox := func(round int, varIndex int) error { - if err := gkr.RegisterGate(gateNames.Integrated(varIndex, round), extKeySBoxGate(&p.RoundKeys[round][varIndex]), 2, gkr.WithUnverifiedDegree(poseidon2.DegreeSBox()), gkr.WithNoSolvableVar()); err != nil { - return err - } - - return gkr.RegisterGate(gateNames.Linear(varIndex, round), extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)) - } - - intKeySBox2 := func(round int) error { - if err := gkr.RegisterGate(gateNames.Linear(y, round), intKeyGate2(&p.RoundKeys[round][1]), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { - return err - } - return gkr.RegisterGate(gateNames.Integrated(y, round), intKeySBoxGate2(&p.RoundKeys[round][1]), 2, gkr.WithUnverifiedDegree(poseidon2.DegreeSBox()), gkr.WithNoSolvableVar()) - } - - fullRound := func(i int) error { - if err := extKeySBox(i, x); err != nil { - return err - } - return extKeySBox(i, y) - } - - for i := range halfRf { - if err = fullRound(i); err != nil { - return - } - } - - { // i = halfRf: first partial round - if err = extKeySBox(halfRf, x); err != nil { - return - } - if err = gkr.RegisterGate(gateNames.Linear(y, halfRf), extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { - return - } - } - - for i := halfRf + 1; i < halfRf+p.NbPartialRounds; i++ { - if err = extKeySBox(i, x); err != nil { // for x1, intKeySBox is identical to extKeySBox - return - } - if err = gkr.RegisterGate(gateNames.Linear(y, i), intGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { - return - } - } - - { - i := halfRf + p.NbPartialRounds - if err = extKeySBox(i, x); err != nil { - return - } - if err = intKeySBox2(i); err != nil { - return - } - } - - for i := halfRf + p.NbPartialRounds + 1; i < p.NbPartialRounds+p.NbFullRounds; i++ { - if err = fullRound(i); err != nil { - return - } - } - - err = gkr.RegisterGate(gateNames.Linear(y, p.NbPartialRounds+p.NbFullRounds), extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)) - }, - ) - return err -} diff --git a/ecc/bls12-377/fr/sumcheck/sumcheck.go b/ecc/bls12-377/fr/sumcheck/sumcheck.go deleted file mode 100644 index b3258bfaa..000000000 --- a/ecc/bls12-377/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bls12-377/fr/sumcheck/sumcheck_test.go b/ecc/bls12-377/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index 6634b5b72..000000000 --- a/ecc/bls12-377/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bls12-377/fr/test_vector_utils/test_vector_utils.go b/ecc/bls12-377/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index f6cbca6f6..000000000 --- a/ecc/bls12-377/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/bls12-381/fr/gkr/gkr.go b/ecc/bls12-381/fr/gkr/gkr.go deleted file mode 100644 index fd4395813..000000000 --- a/ecc/bls12-381/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bls12-381/fr/gkr/gkr_test.go b/ecc/bls12-381/fr/gkr/gkr_test.go deleted file mode 100644 index 5249eb561..000000000 --- a/ecc/bls12-381/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bls12-381/fr/gkr/registry.go b/ecc/bls12-381/fr/gkr/registry.go deleted file mode 100644 index 30cc1273f..000000000 --- a/ecc/bls12-381/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bls12-381/fr/sumcheck/sumcheck.go b/ecc/bls12-381/fr/sumcheck/sumcheck.go deleted file mode 100644 index 234d182ab..000000000 --- a/ecc/bls12-381/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bls12-381/fr/sumcheck/sumcheck_test.go b/ecc/bls12-381/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index 2ec7d9c16..000000000 --- a/ecc/bls12-381/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bls12-381/fr/test_vector_utils/test_vector_utils.go b/ecc/bls12-381/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index 7a1414595..000000000 --- a/ecc/bls12-381/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/bls24-315/fr/gkr/gkr.go b/ecc/bls24-315/fr/gkr/gkr.go deleted file mode 100644 index 24ad81d32..000000000 --- a/ecc/bls24-315/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bls24-315/fr/gkr/gkr_test.go b/ecc/bls24-315/fr/gkr/gkr_test.go deleted file mode 100644 index 12b0c0d46..000000000 --- a/ecc/bls24-315/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bls24-315/fr/gkr/registry.go b/ecc/bls24-315/fr/gkr/registry.go deleted file mode 100644 index eaa05cedd..000000000 --- a/ecc/bls24-315/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bls24-315/fr/sumcheck/sumcheck.go b/ecc/bls24-315/fr/sumcheck/sumcheck.go deleted file mode 100644 index 8a256bc38..000000000 --- a/ecc/bls24-315/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bls24-315/fr/sumcheck/sumcheck_test.go b/ecc/bls24-315/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index cc64cb58f..000000000 --- a/ecc/bls24-315/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bls24-315/fr/test_vector_utils/test_vector_utils.go b/ecc/bls24-315/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index de3318ae9..000000000 --- a/ecc/bls24-315/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/bls24-317/fr/gkr/gkr.go b/ecc/bls24-317/fr/gkr/gkr.go deleted file mode 100644 index 308eec8e3..000000000 --- a/ecc/bls24-317/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bls24-317/fr/gkr/gkr_test.go b/ecc/bls24-317/fr/gkr/gkr_test.go deleted file mode 100644 index 5116784d0..000000000 --- a/ecc/bls24-317/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bls24-317/fr/gkr/registry.go b/ecc/bls24-317/fr/gkr/registry.go deleted file mode 100644 index 28aad164b..000000000 --- a/ecc/bls24-317/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bls24-317/fr/sumcheck/sumcheck.go b/ecc/bls24-317/fr/sumcheck/sumcheck.go deleted file mode 100644 index 16e185683..000000000 --- a/ecc/bls24-317/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bls24-317/fr/sumcheck/sumcheck_test.go b/ecc/bls24-317/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index a950b2b52..000000000 --- a/ecc/bls24-317/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bls24-317/fr/test_vector_utils/test_vector_utils.go b/ecc/bls24-317/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index 1f91003ca..000000000 --- a/ecc/bls24-317/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/bn254/fr/gkr/gkr.go b/ecc/bn254/fr/gkr/gkr.go deleted file mode 100644 index cfc28364d..000000000 --- a/ecc/bn254/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bn254/fr/gkr/gkr_test.go b/ecc/bn254/fr/gkr/gkr_test.go deleted file mode 100644 index 7e41436d0..000000000 --- a/ecc/bn254/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bn254/fr/gkr/registry.go b/ecc/bn254/fr/gkr/registry.go deleted file mode 100644 index 3faeba1c5..000000000 --- a/ecc/bn254/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bn254/fr/sumcheck/sumcheck.go b/ecc/bn254/fr/sumcheck/sumcheck.go deleted file mode 100644 index 89b7d9b8e..000000000 --- a/ecc/bn254/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bn254/fr/sumcheck/sumcheck_test.go b/ecc/bn254/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index 563ca3d28..000000000 --- a/ecc/bn254/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bn254/fr/test_vector_utils/test_vector_utils.go b/ecc/bn254/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index be73caa59..000000000 --- a/ecc/bn254/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/bn254/fr/test_vector_utils/test_vector_utils_test.go b/ecc/bn254/fr/test_vector_utils/test_vector_utils_test.go deleted file mode 100644 index 261b27686..000000000 --- a/ecc/bn254/fr/test_vector_utils/test_vector_utils_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package test_vector_utils - -import ( - "testing" - - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" - "github.com/stretchr/testify/assert" -) - -func TestCounterTranscriptInequality(t *testing.T) { - const challengeName = "fC.0" - t1 := fiatshamir.NewTranscript(test_vector_utils.NewMessageCounter(1, 1), challengeName) - t2 := fiatshamir.NewTranscript(test_vector_utils.NewMessageCounter(0, 1), challengeName) - var c1, c2 []byte - var err error - c1, err = t1.ComputeChallenge(challengeName) - assert.NoError(t, err) - c2, err = t2.ComputeChallenge(challengeName) - assert.NoError(t, err) - assert.NotEqual(t, c1, c2) -} diff --git a/ecc/bw6-633/fr/gkr/gkr.go b/ecc/bw6-633/fr/gkr/gkr.go deleted file mode 100644 index 221b985c5..000000000 --- a/ecc/bw6-633/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bw6-633/fr/gkr/gkr_test.go b/ecc/bw6-633/fr/gkr/gkr_test.go deleted file mode 100644 index ceda23276..000000000 --- a/ecc/bw6-633/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bw6-633/fr/gkr/registry.go b/ecc/bw6-633/fr/gkr/registry.go deleted file mode 100644 index dfe6d2f45..000000000 --- a/ecc/bw6-633/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bw6-633/fr/sumcheck/sumcheck.go b/ecc/bw6-633/fr/sumcheck/sumcheck.go deleted file mode 100644 index 46a4f1dfc..000000000 --- a/ecc/bw6-633/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bw6-633/fr/sumcheck/sumcheck_test.go b/ecc/bw6-633/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index b214ea266..000000000 --- a/ecc/bw6-633/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bw6-633/fr/test_vector_utils/test_vector_utils.go b/ecc/bw6-633/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index 11599a323..000000000 --- a/ecc/bw6-633/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/bw6-761/fr/gkr/gkr.go b/ecc/bw6-761/fr/gkr/gkr.go deleted file mode 100644 index c188d25aa..000000000 --- a/ecc/bw6-761/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/bw6-761/fr/gkr/gkr_test.go b/ecc/bw6-761/fr/gkr/gkr_test.go deleted file mode 100644 index da68edf07..000000000 --- a/ecc/bw6-761/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/bw6-761/fr/gkr/registry.go b/ecc/bw6-761/fr/gkr/registry.go deleted file mode 100644 index fa5e2f605..000000000 --- a/ecc/bw6-761/fr/gkr/registry.go +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := fr.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/bw6-761/fr/sumcheck/sumcheck.go b/ecc/bw6-761/fr/sumcheck/sumcheck.go deleted file mode 100644 index 779fbe7fe..000000000 --- a/ecc/bw6-761/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/bw6-761/fr/sumcheck/sumcheck_test.go b/ecc/bw6-761/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index 875d82578..000000000 --- a/ecc/bw6-761/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/bw6-761/fr/test_vector_utils/test_vector_utils.go b/ecc/bw6-761/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index 26530db60..000000000 --- a/ecc/bw6-761/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/ecc/grumpkin/fr/gkr/gkr.go b/ecc/grumpkin/fr/gkr/gkr.go deleted file mode 100644 index d8868189d..000000000 --- a/ecc/grumpkin/fr/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...fr.Element) fr.Element - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation fr.Element - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]fr.Element, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step fr.Element - - res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - - //defer the proof, return list of claims - evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { - res := make([]fr.Element, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []fr.Element{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []fr.Element - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func topologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := topologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]fr.Element, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]fr.Element, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []fr.Element) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/ecc/grumpkin/fr/gkr/gkr_test.go b/ecc/grumpkin/fr/gkr/gkr_test.go deleted file mode 100644 index d9692f2f5..000000000 --- a/ecc/grumpkin/fr/gkr/gkr_test.go +++ /dev/null @@ -1,828 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/mimc" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []fr.Element{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) -} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { - return func(t *testing.T, inputAssignments ...[]fr.Element) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []fr.Element{three}, five) - manager.add(wire, []fr.Element{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six fr.Element - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { - fullAssignments := make([][]fr.Element, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]fr.Element, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []fr.Element) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "../../../../internal/generator/gkr/test_vectors" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := topologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := topologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []fr.Element(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []fr.Element - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/ecc/grumpkin/fr/gkr/registry.go b/ecc/grumpkin/fr/gkr/registry.go deleted file mode 100644 index d142823bb..000000000 --- a/ecc/grumpkin/fr/gkr/registry.go +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(fr.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]fr.Element, nbIn) - consts := make(fr.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - x := make(fr.Vector, degreeBound) - x.MustSetRandom() - for i := range x { - fIn[0] = x[i] - for j := range consts { - fIn[j+1].Mul(&x[i], &consts[j]) - } - p[i] = f(fIn...) - } - - // obtain p's coefficients - p, err := interpolate(x, p) - if err != nil { - panic(err) - } - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) -// Note that the runtime is O(len(X)³) -func interpolate(X, Y []fr.Element) (polynomial.Polynomial, error) { - if len(X) != len(Y) { - return nil, errors.New("X and Y must have the same length") - } - - // solve the system of equations by Gaussian elimination - augmentedRows := make([][]fr.Element, len(X)) // the last column is the Y values - for i := range augmentedRows { - augmentedRows[i] = make([]fr.Element, len(X)+1) - augmentedRows[i][0].SetOne() - augmentedRows[i][1].Set(&X[i]) - for j := 2; j < len(augmentedRows[i])-1; j++ { - augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) - } - augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) - } - - // make the upper triangle - for i := range len(augmentedRows) - 1 { - // use row i to eliminate the ith element in all rows below - var negInv fr.Element - if augmentedRows[i][i].IsZero() { - return nil, errors.New("singular matrix") - } - negInv.Inverse(&augmentedRows[i][i]) - negInv.Neg(&negInv) - for j := i + 1; j < len(augmentedRows); j++ { - var c fr.Element - c.Mul(&augmentedRows[j][i], &negInv) - // augmentedRows[j][i].SetZero() omitted - for k := i + 1; k < len(augmentedRows[i]); k++ { - var t fr.Element - t.Mul(&augmentedRows[i][k], &c) - augmentedRows[j][k].Add(&augmentedRows[j][k], &t) - } - } - } - - // back substitution - res := make(polynomial.Polynomial, len(X)) - for i := len(augmentedRows) - 1; i >= 0; i-- { - res[i] = augmentedRows[i][len(augmentedRows[i])-1] - for j := i + 1; j < len(augmentedRows[i])-1; j++ { - var t fr.Element - t.Mul(&res[j], &augmentedRows[i][j]) - res[i].Sub(&res[i], &t) - } - res[i].Div(&res[i], &augmentedRows[i][i]) - } - - return res, nil -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/ecc/grumpkin/fr/sumcheck/sumcheck.go b/ecc/grumpkin/fr/sumcheck/sumcheck.go deleted file mode 100644 index e901d8479..000000000 --- a/ecc/grumpkin/fr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return fr.Element{}, err - } - } - var res fr.Element - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff fr.Element - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]fr.Element, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff fr.Element - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]fr.Element, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/ecc/grumpkin/fr/sumcheck/sumcheck_test.go b/ecc/grumpkin/fr/sumcheck/sumcheck_test.go deleted file mode 100644 index e7cc10581..000000000 --- a/ecc/grumpkin/fr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []fr.Element{sum} -} - -func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum fr.Element -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/ecc/grumpkin/fr/test_vector_utils/test_vector_utils.go b/ecc/grumpkin/fr/test_vector_utils/test_vector_utils.go deleted file mode 100644 index df83ecc9b..000000000 --- a/ecc/grumpkin/fr/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr" - "github.com/consensys/gnark-crypto/ecc/grumpkin/fr/polynomial" - "hash" - "reflect" - "strings" -) - -func ToElement(i int64) *fr.Element { - var res fr.Element - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/fr.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/fr.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res fr.Element - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return fr.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return fr.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []fr.Element - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return fr.Bytes -} - -func (h *ListHash) BlockSize() int { - return fr.Bytes -} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} - -func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { - elementSlice := make([]fr.Element, len(slice)) - for i, v := range slice { - if _, err := SetElement(&elementSlice[i], v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []fr.Element, b []fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *fr.Element) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(fr.Element) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/field/generator/config/field_config.go b/field/generator/config/field_config.go index a8485cb8e..b06c0063b 100644 --- a/field/generator/config/field_config.go +++ b/field/generator/config/field_config.go @@ -491,3 +491,9 @@ func (f *Field) WriteElement(element Element) string { builder.WriteString("}") return builder.String() } + +type FieldDependency struct { + ElementType string + FieldPackagePath string + FieldPackageName string +} diff --git a/internal/generator/config/curve.go b/internal/generator/config/curve.go index c490a1143..12e6c6f77 100644 --- a/internal/generator/config/curve.go +++ b/internal/generator/config/curve.go @@ -94,9 +94,3 @@ func newFieldInfo(modulus string) Field { F.Modulus = func() *big.Int { return new(big.Int).Set(&bModulus) } return F } - -type FieldDependency struct { - FieldPackagePath string - ElementType string - FieldPackageName string -} diff --git a/internal/generator/gkr/generate.go b/internal/generator/gkr/generate.go deleted file mode 100644 index fb9211200..000000000 --- a/internal/generator/gkr/generate.go +++ /dev/null @@ -1,31 +0,0 @@ -package gkr - -import ( - "path/filepath" - - "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/generator/config" -) - -type Config struct { - config.FieldDependency - GenerateTests bool - RetainTestCaseRawInfo bool - CanUseFFT bool - OutsideGkrPackage bool - TestVectorsRelativePath string -} - -func Generate(config Config, baseDir string, bgen *bavard.BatchGenerator) error { - entries := []bavard.Entry{ - {File: filepath.Join(baseDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, - {File: filepath.Join(baseDir, "registry.go"), Templates: []string{"registry.go.tmpl"}}, - } - - if config.GenerateTests { - entries = append(entries, - bavard.Entry{File: filepath.Join(baseDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}}) - } - - return bgen.Generate(config, "gkr", "./gkr/template/", entries...) -} diff --git a/internal/generator/gkr/template/gkr.go.tmpl b/internal/generator/gkr/template/gkr.go.tmpl deleted file mode 100644 index c27daa9b5..000000000 --- a/internal/generator/gkr/template/gkr.go.tmpl +++ /dev/null @@ -1,863 +0,0 @@ -import ( - "errors" - "fmt" - "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/polynomial" - "{{.FieldPackagePath}}/sumcheck" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...{{.ElementType}}) {{.ElementType}} - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]{{.ElementType}} - claimedEvaluations []{{.ElementType}} - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]{{.ElementType}}) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation {{.ElementType}} - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]{{.ElementType}} // x in the paper - claimedEvaluations []{{.ElementType}} // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq,c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.ElementType}}) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]{{.ElementType}}, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step {{.ElementType}} - - res := make([]{{.ElementType}}, degGJ) - operands := make([]{{.ElementType}}, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} { - - //defer the proof, return list of claims - evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]{{.ElementType}}, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []{{.ElementType}}, evaluation {{.ElementType}}) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func (options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.ElementType}}, error) { - res := make([]{{.ElementType}}, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []{{.ElementType}} - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []{{.ElementType}}{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]{{.ElementType}}) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []{{.ElementType}} - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]{{.ElementType}}) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// {{$topologicalSort}} sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func {{$topologicalSort}}(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := {{$topologicalSort}}(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]{{.ElementType}}, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]{{.ElementType}}, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]{{.ElementType}}) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []{{.ElementType}}) { - for i := range src { - src[i].BigInt(dst[i]) - } -} \ No newline at end of file diff --git a/internal/generator/gkr/template/gkr.test.go.tmpl b/internal/generator/gkr/template/gkr.test.go.tmpl deleted file mode 100644 index 378cb813e..000000000 --- a/internal/generator/gkr/template/gkr.test.go.tmpl +++ /dev/null @@ -1,611 +0,0 @@ - -import ( - "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/mimc" - "{{.FieldPackagePath}}/polynomial" - "{{.FieldPackagePath}}/sumcheck" - "{{.FieldPackagePath}}/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/stretchr/testify/assert" - "fmt" - "hash" - "os" - "strconv" - "testing" - "path/filepath" - "encoding/json" - "reflect" - "time" -) - -{{$GenerateLargeTests := .GenerateTests}} {{/* this is redundant. soon to be removed if a use case for it doesn't come back */}} -{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []{{.ElementType}}{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []{{.ElementType}}{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []{{.ElementType}}{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) -} - -{{- if $GenerateLargeTests}} -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) { - return func(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - testMimc(t, numRounds, inputAssignments...) - } -} - -{{- end}} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{ Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - } } - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []{{.ElementType}}{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []{{.ElementType}}{three}, five) - manager.add(wire, []{{.ElementType}}{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six {{.ElementType}} - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]{{.ElementType}})) { - fullAssignments := make([][]{{.ElementType}}, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]{{.ElementType}}, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t,err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]{{.ElementType}}) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []{{.ElementType}}) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "{{.TestVectorsRelativePath}}" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := {{$topologicalSort}}(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := {{$topologicalSort}}(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := {{$topologicalSort}}(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -{{template "gkrTestVectors" .}} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} \ No newline at end of file diff --git a/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl b/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl deleted file mode 100644 index 832188f3d..000000000 --- a/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl +++ /dev/null @@ -1,123 +0,0 @@ -import ( - "encoding/json" - "fmt" - "hash" - "os" - "path/filepath" - "reflect" - - "github.com/consensys/bavard" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" - -) - -func main() { - if err := GenerateVectors(); err != nil { - fmt.Println(err.Error()) - os.Exit(-1) - } -} - -func GenerateVectors() error { - testDirPath, err := filepath.Abs("gkr/test_vectors") - if err != nil { - return err - } - - fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) - - dirEntries, err := os.ReadDir(testDirPath) - if err != nil { - return err - } - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - if !bavard.ShouldGenerate(path) { - continue - } - fmt.Println("\tprocessing", dirEntry.Name()) - if err = run(path); err != nil { - return err - } - } - } - } - - return nil -} - -func run(absPath string) error { - testCase, err := newTestCase(absPath) - if err != nil { - return err - } - - transcriptSetting := fiatshamir.WithHash(testCase.Hash) - - var proof gkr.Proof - proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) - if err != nil { - return err - } - - if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { - return err - } - var outBytes []byte - if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { - if err = os.WriteFile(absPath, outBytes, 0); err != nil { - return err - } - } else { - return err - } - - testCase, err = newTestCase(absPath) - if err != nil { - return err - } - - err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) - if err != nil { - return err - } - - testCase, err = newTestCase(absPath) - if err != nil { - return err - } - - err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - if err == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { - res := make(PrintableProof, len(proof)) - - for i := range proof { - - partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) - for k, partialK := range proof[i].PartialSumPolys { - partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) - } - - res[i] = PrintableSumcheckProof{ - FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), - PartialSumPolys: partialSumPolys, - } - } - return res, nil -} - -{{template "gkrTestVectors" .}} \ No newline at end of file diff --git a/internal/generator/gkr/template/gkr.test.vectors.go.tmpl b/internal/generator/gkr/template/gkr.test.vectors.go.tmpl deleted file mode 100644 index 0025b0164..000000000 --- a/internal/generator/gkr/template/gkr.test.vectors.go.tmpl +++ /dev/null @@ -1,254 +0,0 @@ -{{define "gkrTestVectors"}} - -{{$GkrPackagePrefix := select .OutsideGkrPackage "" "gkr."}} -{{$CheckOutputCorrectness := not .OutsideGkrPackage}} - -{{$Circuit := print $GkrPackagePrefix "Circuit"}} -{{$Gate := print $GkrPackagePrefix "Gate"}} -{{$Proof := print $GkrPackagePrefix "Proof"}} -{{$WireAssignment := print $GkrPackagePrefix "WireAssignment"}} -{{$Wire := print $GkrPackagePrefix "Wire"}} -{{$CircuitLayer := print $GkrPackagePrefix "CircuitLayer"}} - -{{$PackagePrefix := ""}} -{{- if .OutsideGkrPackage}} - {{$PackagePrefix = "gkr."}} -{{end}} - -type WireInfo struct { - Gate {{$PackagePrefix}}GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]{{$Circuit}}) - -func getCircuit(path string) ({{$Circuit}}, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit {{$Circuit}}) { - circuit = make({{$Circuit}}, len(c)) - for i := range c { - circuit[i].Gate = {{$PackagePrefix}}GetGate(c[i].Gate) - circuit[i].Inputs = make([]*{{$Wire}}, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...{{.ElementType}}) (res {{.ElementType}}) { - var sum {{.ElementType}} - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC {{$PackagePrefix}}GateName = "mimc" - SelectInput3 {{$PackagePrefix}}GateName = "select-input-3" -) - -func init() { - if err := {{$PackagePrefix}}RegisterGate(MiMC, mimcRound, 2, {{$PackagePrefix}}WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := {{$PackagePrefix}}RegisterGate(SelectInput3, func(input ...{{.ElementType}}) {{.ElementType}} { - return input[2] - }, 3, {{$PackagePrefix}}WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) ({{$Proof}}, error) { - proof := make({{$Proof}}, len(printable)) - for i := range printable { - finalEvalProof := []{{.ElementType}}(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]{{.ElementType}}, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := {{ setElement "finalEvalProof[k]" "finalEvalSlice.Index(k).Interface()" .ElementType}}; err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit {{$Circuit}} - Hash hash.Hash - Proof {{$Proof}} - FullAssignment {{$WireAssignment}} - InOutAssignment {{$WireAssignment}} - {{if .RetainTestCaseRawInfo}}Info TestCaseInfo{{end}} -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit {{$Circuit}} - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof {{$Proof}} - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make({{$WireAssignment}}) - inOutAssignment := make({{$WireAssignment}}) - - sorted := {{select .OutsideGkrPackage "t" "gkr.T"}}opologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []{{.ElementType}} - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - {{if not $CheckOutputCorrectness}} - info.Output = make([][]interface{}, 0, outI) - {{end}} - - for _, w := range sorted { - if w.IsOutput() { - {{if $CheckOutputCorrectness}} - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - {{else}} - info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) - {{end}} - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - {{if .RetainTestCaseRawInfo }}Info: info,{{end}} - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -{{end}} - -{{- define "setElement element value elementType"}} -{{- if eq .elementType "fr.Element"}} test_vector_utils.SetElement(&{{.element}}, {{.value}}) -{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) -{{- else}} -{{print "\"UNEXPECTED TYPE" .elementType "\""}} -{{- end}} -{{- end}} \ No newline at end of file diff --git a/internal/generator/gkr/template/registry.go.tmpl b/internal/generator/gkr/template/registry.go.tmpl deleted file mode 100644 index 75ca8d026..000000000 --- a/internal/generator/gkr/template/registry.go.tmpl +++ /dev/null @@ -1,390 +0,0 @@ -import ( - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "{{.FieldPackagePath}}" - {{- if .CanUseFFT }} - "{{.FieldPackagePath}}/fft"{{- else}} - "errors"{{- end }} - "{{.FieldPackagePath}}/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make({{.FieldPackageName}}.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]{{.ElementType}}, nbIn) - consts := make({{.FieldPackageName}}.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - {{- if .CanUseFFT }} - domain := fft.NewDomain(degreeBound) - // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) - x := {{.FieldPackageName}}.One() - for i := range p { - fIn[0] = x - for j := range consts { - fIn[j+1].Mul(&x, &consts[j]) - } - p[i] = f(fIn...) - - x.Mul(&x, &domain.Generator) - } - - // obtain p's coefficients - domain.FFTInverse(p, fft.DIF) - fft.BitReverse(p) - {{- else }} - x := make({{.FieldPackageName}}.Vector, degreeBound) - x.MustSetRandom() - for i := range x { - fIn[0] = x[i] - for j := range consts { - fIn[j+1].Mul(&x[i], &consts[j]) - } - p[i] = f(fIn...) - } - - // obtain p's coefficients - p, err := interpolate(x, p) - if err != nil { - panic(err) - } - {{- end }} - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max)+1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -{{- if not .CanUseFFT }} -// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) -// Note that the runtime is O(len(X)³) -func interpolate(X, Y []{{.ElementType}}) (polynomial.Polynomial, error) { - if len(X) != len(Y) { - return nil, errors.New("X and Y must have the same length") - } - - // solve the system of equations by Gaussian elimination - augmentedRows := make([][]{{.ElementType}}, len(X)) // the last column is the Y values - for i := range augmentedRows { - augmentedRows[i] = make([]{{.ElementType}}, len(X)+1) - augmentedRows[i][0].SetOne() - augmentedRows[i][1].Set(&X[i]) - for j := 2; j < len(augmentedRows[i])-1; j++ { - augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) - } - augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) - } - - // make the upper triangle - for i := range len(augmentedRows) - 1 { - // use row i to eliminate the ith element in all rows below - var negInv {{.ElementType}} - if augmentedRows[i][i].IsZero() { - return nil, errors.New("singular matrix") - } - negInv.Inverse(&augmentedRows[i][i]) - negInv.Neg(&negInv) - for j := i + 1; j < len(augmentedRows); j++ { - var c {{.ElementType}} - c.Mul(&augmentedRows[j][i], &negInv) - // augmentedRows[j][i].SetZero() omitted - for k := i + 1; k < len(augmentedRows[i]); k++ { - var t {{.ElementType}} - t.Mul(&augmentedRows[i][k], &c) - augmentedRows[j][k].Add(&augmentedRows[j][k], &t) - } - } - } - - // back substitution - res := make(polynomial.Polynomial, len(X)) - for i := len(augmentedRows) - 1; i >= 0; i-- { - res[i] = augmentedRows[i][len(augmentedRows[i])-1] - for j := i + 1; j < len(augmentedRows[i])-1; j++ { - var t {{.ElementType}} - t.Mul(&res[j], &augmentedRows[i][j]) - res[i].Sub(&res[i], &t) - } - res[i].Div(&res[i], &augmentedRows[i][i]) - } - - return res, nil -} -{{- end }} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...{{.ElementType}}) {{.ElementType}} { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/main.go b/internal/generator/gkr/test_vectors/main.go deleted file mode 100644 index 0bb86739a..000000000 --- a/internal/generator/gkr/test_vectors/main.go +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package main - -import ( - "encoding/json" - "fmt" - "hash" - "os" - "path/filepath" - "reflect" - - "github.com/consensys/bavard" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" -) - -func main() { - if err := GenerateVectors(); err != nil { - fmt.Println(err.Error()) - os.Exit(-1) - } -} - -func GenerateVectors() error { - testDirPath, err := filepath.Abs("gkr/test_vectors") - if err != nil { - return err - } - - fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) - - dirEntries, err := os.ReadDir(testDirPath) - if err != nil { - return err - } - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - if !bavard.ShouldGenerate(path) { - continue - } - fmt.Println("\tprocessing", dirEntry.Name()) - if err = run(path); err != nil { - return err - } - } - } - } - - return nil -} - -func run(absPath string) error { - testCase, err := newTestCase(absPath) - if err != nil { - return err - } - - transcriptSetting := fiatshamir.WithHash(testCase.Hash) - - var proof gkr.Proof - proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) - if err != nil { - return err - } - - if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { - return err - } - var outBytes []byte - if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { - if err = os.WriteFile(absPath, outBytes, 0); err != nil { - return err - } - } else { - return err - } - - testCase, err = newTestCase(absPath) - if err != nil { - return err - } - - err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) - if err != nil { - return err - } - - testCase, err = newTestCase(absPath) - if err != nil { - return err - } - - err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - if err == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { - res := make(PrintableProof, len(proof)) - - for i := range proof { - - partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) - for k, partialK := range proof[i].PartialSumPolys { - partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) - } - - res[i] = PrintableSumcheckProof{ - FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), - PartialSumPolys: partialSumPolys, - } - } - return res, nil -} - -type WireInfo struct { - Gate gkr.GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]gkr.Circuit) - -func getCircuit(path string) (gkr.Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) { - circuit = make(gkr.Circuit, len(c)) - for i := range c { - circuit[i].Gate = gkr.GetGate(c[i].Gate) - circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { - var sum small_rational.SmallRational - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC gkr.GateName = "mimc" - SelectInput3 gkr.GateName = "select-input-3" -) - -func init() { - if err := gkr.RegisterGate(MiMC, mimcRound, 2, gkr.WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := gkr.RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { - return input[2] - }, 3, gkr.WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (gkr.Proof, error) { - proof := make(gkr.Proof, len(printable)) - for i := range printable { - finalEvalProof := []small_rational.SmallRational(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit gkr.Circuit - Hash hash.Hash - Proof gkr.Proof - FullAssignment gkr.WireAssignment - InOutAssignment gkr.WireAssignment - Info TestCaseInfo -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit gkr.Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof gkr.Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(gkr.WireAssignment) - inOutAssignment := make(gkr.WireAssignment) - - sorted := gkr.TopologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []small_rational.SmallRational - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - info.Output = make([][]interface{}, 0, outI) - - for _, w := range sorted { - if w.IsOutput() { - - info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - Info: info, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} diff --git a/internal/generator/gkr/test_vectors/mimc_five_levels_two_instances._json b/internal/generator/gkr/test_vectors/mimc_five_levels_two_instances._json deleted file mode 100644 index 446d23fdb..000000000 --- a/internal/generator/gkr/test_vectors/mimc_five_levels_two_instances._json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "hash": {"type": "const", "val": -1}, - "circuit": "resources/mimc_five_levels.json", - "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], - "output": [[4, 3]], - "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/mimc_five_levels.json b/internal/generator/gkr/test_vectors/resources/mimc_five_levels.json deleted file mode 100644 index 3dd74f42b..000000000 --- a/internal/generator/gkr/test_vectors/resources/mimc_five_levels.json +++ /dev/null @@ -1,36 +0,0 @@ -[ - [ - { - "gate": "mimc", - "inputs": [[1,0], [5,5]] - } - ], - [ - { - "gate": "mimc", - "inputs": [[2,0], [5,4]] - } - ], - [ - { - "gate": "mimc", - "inputs": [[3,0], [5,3]] - } - ], - [ - { - "gate": "mimc", - "inputs": [[4,0], [5,2]] - } - ], - [ - { - "gate": "mimc", - "inputs": [[5,0], [5,1]] - } - ], - [ - {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, - {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} - ] -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_identity_gate.json b/internal/generator/gkr/test_vectors/resources/single_identity_gate.json deleted file mode 100644 index a44066c7b..000000000 --- a/internal/generator/gkr/test_vectors/resources/single_identity_gate.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "gate": null, - "inputs": [] - }, - { - "gate": "identity", - "inputs": [0] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_input_two_identity_gates.json b/internal/generator/gkr/test_vectors/resources/single_input_two_identity_gates.json deleted file mode 100644 index 6181784fa..000000000 --- a/internal/generator/gkr/test_vectors/resources/single_input_two_identity_gates.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "gate": null, - "inputs": [] - }, - { - "gate": "identity", - "inputs": [0] - }, - { - "gate": "identity", - "inputs": [0] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_input_two_outs.json b/internal/generator/gkr/test_vectors/resources/single_input_two_outs.json deleted file mode 100644 index 3a39e5625..000000000 --- a/internal/generator/gkr/test_vectors/resources/single_input_two_outs.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "gate": null, - "inputs": [] - }, - { - "gate": "mul2", - "inputs": [0, 0] - }, - { - "gate": "identity", - "inputs": [0] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_mimc_gate.json b/internal/generator/gkr/test_vectors/resources/single_mimc_gate.json deleted file mode 100644 index c89e7d52a..000000000 --- a/internal/generator/gkr/test_vectors/resources/single_mimc_gate.json +++ /dev/null @@ -1,7 +0,0 @@ -[ - {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, - { - "gate": "mimc", - "inputs": [0, 1] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_mul_gate.json b/internal/generator/gkr/test_vectors/resources/single_mul_gate.json deleted file mode 100644 index d009ebe03..000000000 --- a/internal/generator/gkr/test_vectors/resources/single_mul_gate.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "gate": null, - "inputs": [] - }, - { - "gate": null, - "inputs": [] - }, - { - "gate": "mul2", - "inputs": [0, 1] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json b/internal/generator/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json deleted file mode 100644 index 26681c2f8..000000000 --- a/internal/generator/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "gate": null, - "inputs": [] - }, - { - "gate": "identity", - "inputs": [0] - }, - { - "gate": "identity", - "inputs": [1] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json b/internal/generator/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json deleted file mode 100644 index cdbdb3b47..000000000 --- a/internal/generator/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "gate": null, - "inputs": [] - }, - { - "gate": null, - "inputs": [] - }, - { - "gate": "select-input-3", - "inputs": [0,0,1] - } -] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_identity_gate_two_instances.json b/internal/generator/gkr/test_vectors/single_identity_gate_two_instances.json deleted file mode 100644 index ce326d0a6..000000000 --- a/internal/generator/gkr/test_vectors/single_identity_gate_two_instances.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/single_identity_gate.json", - "input": [ - [ - 4, - 3 - ] - ], - "output": [ - [ - 4, - 3 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [ - 5 - ], - "partialSumPolys": [ - [ - -3, - -8 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/internal/generator/gkr/test_vectors/single_input_two_identity_gates_two_instances.json deleted file mode 100644 index 2c95f044f..000000000 --- a/internal/generator/gkr/test_vectors/single_input_two_identity_gates_two_instances.json +++ /dev/null @@ -1,56 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/single_input_two_identity_gates.json", - "input": [ - [ - 2, - 3 - ] - ], - "output": [ - [ - 2, - 3 - ], - [ - 2, - 3 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [ - [ - 0, - 0 - ] - ] - }, - { - "finalEvalProof": [ - 1 - ], - "partialSumPolys": [ - [ - -3, - -16 - ] - ] - }, - { - "finalEvalProof": [ - 1 - ], - "partialSumPolys": [ - [ - -3, - -16 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_input_two_outs_two_instances.json b/internal/generator/gkr/test_vectors/single_input_two_outs_two_instances.json deleted file mode 100644 index d348303d0..000000000 --- a/internal/generator/gkr/test_vectors/single_input_two_outs_two_instances.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/single_input_two_outs.json", - "input": [ - [ - 1, - 2 - ] - ], - "output": [ - [ - 1, - 4 - ], - [ - 1, - 2 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [ - [ - 0, - 0 - ] - ] - }, - { - "finalEvalProof": [ - 0 - ], - "partialSumPolys": [ - [ - -4, - -36, - -112 - ] - ] - }, - { - "finalEvalProof": [ - 0 - ], - "partialSumPolys": [ - [ - -2, - -12 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_mimc_gate_four_instances.json b/internal/generator/gkr/test_vectors/single_mimc_gate_four_instances.json deleted file mode 100644 index ff275c9cb..000000000 --- a/internal/generator/gkr/test_vectors/single_mimc_gate_four_instances.json +++ /dev/null @@ -1,67 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/single_mimc_gate.json", - "input": [ - [ - 1, - 1, - 2, - 1 - ], - [ - 1, - 2, - 2, - 1 - ] - ], - "output": [ - [ - 128, - 2187, - 16384, - 128 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [ - -1, - -3 - ], - "partialSumPolys": [ - [ - -32640, - -2239484, - -29360128, - -200000010, - -931628672, - -3373267120, - -10200858624, - -26939400158 - ], - [ - -81920, - -41943040, - -1254113280, - -13421772800, - -83200000000, - -366917713920, - -1281828208640, - -3779571220480 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_mimc_gate_two_instances.json b/internal/generator/gkr/test_vectors/single_mimc_gate_two_instances.json deleted file mode 100644 index 369297dbd..000000000 --- a/internal/generator/gkr/test_vectors/single_mimc_gate_two_instances.json +++ /dev/null @@ -1,51 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/single_mimc_gate.json", - "input": [ - [ - 1, - 1 - ], - [ - 1, - 2 - ] - ], - "output": [ - [ - 128, - 2187 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [ - 1, - 0 - ], - "partialSumPolys": [ - [ - -2187, - -65536, - -546875, - -2799360, - -10706059, - -33554432, - -90876411, - -220000000 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_mul_gate_two_instances.json b/internal/generator/gkr/test_vectors/single_mul_gate_two_instances.json deleted file mode 100644 index 75c1d59c3..000000000 --- a/internal/generator/gkr/test_vectors/single_mul_gate_two_instances.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/single_mul_gate.json", - "input": [ - [ - 4, - 3 - ], - [ - 2, - 3 - ] - ], - "output": [ - [ - 8, - 9 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [ - 5, - 1 - ], - "partialSumPolys": [ - [ - -9, - -32, - -35 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/internal/generator/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json deleted file mode 100644 index 10e5f1ff3..000000000 --- a/internal/generator/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/two_identity_gates_composed_single_input.json", - "input": [ - [ - 2, - 1 - ] - ], - "output": [ - [ - 2, - 1 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [ - 3 - ], - "partialSumPolys": [ - [ - -1, - 0 - ] - ] - }, - { - "finalEvalProof": [ - 3 - ], - "partialSumPolys": [ - [ - -1, - 0 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/internal/generator/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json deleted file mode 100644 index 19e127df7..000000000 --- a/internal/generator/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json +++ /dev/null @@ -1,45 +0,0 @@ -{ - "hash": { - "type": "const", - "val": -1 - }, - "circuit": "resources/two_inputs_select-input-3_gate.json", - "input": [ - [ - 0, - 1 - ], - [ - 2, - 3 - ] - ], - "output": [ - [ - 2, - 3 - ] - ], - "proof": [ - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [], - "partialSumPolys": [] - }, - { - "finalEvalProof": [ - -1, - 1 - ], - "partialSumPolys": [ - [ - -3, - -16 - ] - ] - } - ] -} \ No newline at end of file diff --git a/internal/generator/hash_to_field/generate.go b/internal/generator/hash_to_field/generate.go index d9c7dbbff..5ba0e4f25 100644 --- a/internal/generator/hash_to_field/generate.go +++ b/internal/generator/hash_to_field/generate.go @@ -1,10 +1,10 @@ package hash_to_field import ( + "github.com/consensys/gnark-crypto/field/generator/config" "path/filepath" "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/generator/config" ) func Generate(conf config.FieldDependency, baseDir string, bgen *bavard.BatchGenerator) error { diff --git a/internal/generator/main.go b/internal/generator/main.go index 11e2b98db..e11bbdb40 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "github.com/consensys/gnark-crypto/internal/generator/gkr" "os" "os/exec" "path/filepath" @@ -31,8 +30,6 @@ import ( "github.com/consensys/gnark-crypto/internal/generator/plookup" "github.com/consensys/gnark-crypto/internal/generator/polynomial" "github.com/consensys/gnark-crypto/internal/generator/shplonk" - "github.com/consensys/gnark-crypto/internal/generator/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils" "github.com/consensys/gnark-crypto/internal/generator/tower" ) @@ -76,22 +73,15 @@ func main() { conf.FpUnusedBits = 64 - (conf.Fp.NbBits % 64) - frInfo := config.FieldDependency{ + frInfo := fieldConfig.FieldDependency{ FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + conf.Name + "/fr", FieldPackageName: "fr", ElementType: "fr.Element", } - gkrConfig := gkr.Config{ - FieldDependency: frInfo, - GenerateTests: true, - TestVectorsRelativePath: "../../../../internal/generator/gkr/test_vectors", - } - frOpts := []generator.Option{generator.WithASM(asmConfig)} if !(conf.Equal(config.STARK_CURVE) || conf.Equal(config.SECP256K1) || conf.Equal(config.GRUMPKIN)) { frOpts = append(frOpts, generator.WithFFT(fftConfig)) - gkrConfig.CanUseFFT = true } if conf.Equal(config.BLS12_377) { frOpts = append(frOpts, generator.WithSIS()) @@ -113,10 +103,6 @@ func main() { return } - // generate gkr on fr - // GKR tests use MiMC. Once SECP256K1 has a mimc implementation, we can generate GKR for it. - assertNoError(gkr.Generate(gkrConfig, filepath.Join(curveDir, "fr", "gkr"), bgen)) - // generate mimc on fr assertNoError(mimc.Generate(conf, filepath.Join(curveDir, "fr", "mimc"), bgen)) @@ -126,16 +112,7 @@ func main() { // generate polynomial on fr assertNoError(polynomial.Generate(frInfo, filepath.Join(curveDir, "fr", "polynomial"), true, bgen)) - // generate sumcheck on fr - assertNoError(sumcheck.Generate(frInfo, filepath.Join(curveDir, "fr", "sumcheck"), bgen)) - - // generate test vector utils on fr - assertNoError(test_vector_utils.Generate(test_vector_utils.Config{ - FieldDependency: frInfo, - RandomizeMissingHashEntries: false, - }, filepath.Join(curveDir, "fr", "test_vector_utils"), bgen)) - - fpInfo := config.FieldDependency{ + fpInfo := fieldConfig.FieldDependency{ FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + conf.Name + "/fp", FieldPackageName: "fp", ElementType: "fp.Element", @@ -205,11 +182,6 @@ func main() { } - wg.Add(1) - go func() { - defer wg.Done() - assertNoError(test_vector_utils.GenerateRationals(bgen)) - }() wg.Wait() // format the whole directory @@ -229,28 +201,6 @@ func main() { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr assertNoError(cmd.Run())*/ - - wg.Add(2) - - go func() { - // generate test vectors for sumcheck - cmd := exec.Command("go", "run", "./sumcheck/test_vectors") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - assertNoError(cmd.Run()) - wg.Done() - }() - - go func() { - // generate test vectors for gkr - cmd := exec.Command("go", "run", "./gkr/test_vectors") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - assertNoError(cmd.Run()) - wg.Done() - }() - - wg.Wait() } func assertNoError(err error) { diff --git a/internal/generator/polynomial/generate.go b/internal/generator/polynomial/generate.go index bdc2ab197..61421a017 100644 --- a/internal/generator/polynomial/generate.go +++ b/internal/generator/polynomial/generate.go @@ -1,10 +1,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/field/generator/config" "path/filepath" "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/generator/config" ) func Generate(conf config.FieldDependency, baseDir string, generateTests bool, bgen *bavard.BatchGenerator) error { diff --git a/internal/generator/polynomial/template/pool.go.tmpl b/internal/generator/polynomial/template/pool.go.tmpl index d3e8acbad..d1119aeb5 100644 --- a/internal/generator/polynomial/template/pool.go.tmpl +++ b/internal/generator/polynomial/template/pool.go.tmpl @@ -1,38 +1,13 @@ -{{ $sham := eq .ElementType "small_rational.SmallRational"}} import ( -"{{.FieldPackagePath}}" -{{- if not $sham}} + "{{.FieldPackagePath}}" "encoding/json" "fmt" "runtime" "sort" "sync" "unsafe" -{{- end}} ) -{{ if $sham}} -// Do as little as possible to instantiate the interface -type Pool struct { -} - -func NewPool(...int) (pool Pool) { - return Pool{} -} - -func (p *Pool) Make(n int) []{{.ElementType}} { - return make([]{{.ElementType}}, n) -} - -func (p *Pool) Dump(...[]{{.ElementType}}) { -} - -func (p *Pool) Clone(slice []{{.ElementType}}) []{{.ElementType}} { - res := p.Make(len(slice)) - copy(res, slice) - return res -} -{{ else}} // Memory management for polynomials // WARNING: This is not thread safe TODO: Make sure that is not a problem // TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly @@ -205,5 +180,4 @@ func (p *Pool) Clone(slice []{{.ElementType}}) []{{.ElementType}} { res := p.Make(len(slice)) copy(res, slice) return res -} -{{end}} \ No newline at end of file +} \ No newline at end of file diff --git a/internal/generator/sumcheck/generate.go b/internal/generator/sumcheck/generate.go deleted file mode 100644 index 8c600221e..000000000 --- a/internal/generator/sumcheck/generate.go +++ /dev/null @@ -1,16 +0,0 @@ -package sumcheck - -import ( - "path/filepath" - - "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/generator/config" -) - -func Generate(conf config.FieldDependency, baseDir string, bgen *bavard.BatchGenerator) error { - entries := []bavard.Entry{ - {File: filepath.Join(baseDir, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, - {File: filepath.Join(baseDir, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, - } - return bgen.Generate(conf, "sumcheck", "./sumcheck/template/", entries...) -} diff --git a/internal/generator/sumcheck/template/sumcheck.go.tmpl b/internal/generator/sumcheck/template/sumcheck.go.tmpl deleted file mode 100644 index 2ca7ec497..000000000 --- a/internal/generator/sumcheck/template/sumcheck.go.tmpl +++ /dev/null @@ -1,163 +0,0 @@ -import ( - "errors" - "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/polynomial" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a {{.ElementType}}) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next({{.ElementType}}) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []{{.ElementType}}) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a {{.ElementType}}) {{.ElementType}} // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []{{.ElementType}}, remainingChallengeNames *[]string) ({{.ElementType}}, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return {{.ElementType}}{}, err - } - } - var res {{.ElementType}} - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff {{.ElementType}} - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]{{.ElementType}}, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff {{.ElementType}} - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]{{.ElementType}}, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/internal/generator/sumcheck/template/sumcheck.test.go.tmpl b/internal/generator/sumcheck/template/sumcheck.test.go.tmpl deleted file mode 100644 index b50c31092..000000000 --- a/internal/generator/sumcheck/template/sumcheck.test.go.tmpl +++ /dev/null @@ -1,143 +0,0 @@ -import ( - "fmt" - "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/polynomial" - "{{.FieldPackagePath}}/test_vector_utils" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []{{.ElementType}}{sum} -} - -func (c singleMultilinClaim) Combine({{.ElementType}}) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r {{.ElementType}}) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum {{.ElementType}} -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs {{.ElementType}}) {{.ElementType}} { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/internal/generator/sumcheck/test_vectors/main.go b/internal/generator/sumcheck/test_vectors/main.go deleted file mode 100644 index 798f5a4f3..000000000 --- a/internal/generator/sumcheck/test_vectors/main.go +++ /dev/null @@ -1,199 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" - "hash" - "math/bits" - "os" - "path/filepath" -) - -func runMultilin(testCaseInfo *TestCaseInfo) error { - - var poly polynomial.MultiLin - if v, err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); err == nil { - poly = v - } else { - return err - } - - var hsh hash.Hash - var err error - if hsh, err = test_vector_utils.HashFromDescription(testCaseInfo.Hash); err != nil { - return err - } - - proof, err := sumcheck.Prove( - &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) - if err != nil { - return err - } - testCaseInfo.Proof = toPrintableProof(proof) - - // Verification - if v, _err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); _err == nil { - poly = v - } else { - return _err - } - var claimedSum small_rational.SmallRational - if _, err = claimedSum.SetInterface(testCaseInfo.ClaimedSum); err != nil { - return err - } - - if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { - return fmt.Errorf("proof rejected: %v", err) - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func run(testCaseInfo *TestCaseInfo) error { - switch testCaseInfo.Type { - case "multilin": - return runMultilin(testCaseInfo) - default: - return fmt.Errorf("type \"%s\" unrecognized", testCaseInfo.Type) - } -} - -func runAll(relPath string) error { - var filename string - var err error - if filename, err = filepath.Abs(relPath); err != nil { - return err - } - - var bytes []byte - - if bytes, err = os.ReadFile(filename); err != nil { - return err - } - - var testCasesInfo TestCasesInfo - if err = json.Unmarshal(bytes, &testCasesInfo); err != nil { - return err - } - - failed := false - for name, testCase := range testCasesInfo { - if err = run(testCase); err != nil { - fmt.Println(name, ":", err) - failed = true - } - } - - if failed { - return fmt.Errorf("test case failed") - } - - if bytes, err = json.MarshalIndent(testCasesInfo, "", "\t"); err != nil { - return err - } - - return os.WriteFile(filename, bytes, 0) -} - -func main() { - if err := runAll("sumcheck/test_vectors/vectors.json"); err != nil { - fmt.Println(err) - os.Exit(-1) - } -} - -type TestCasesInfo map[string]*TestCaseInfo - -type TestCaseInfo struct { - Type string `json:"type"` - Hash test_vector_utils.HashDescription `json:"hash"` - Values []interface{} `json:"values"` - Description string `json:"description"` - Proof PrintableProof `json:"proof"` - ClaimedSum interface{} `json:"claimedSum"` -} - -type PrintableProof struct { - PartialSumPolys [][]interface{} `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` -} - -func toPrintableProof(proof sumcheck.Proof) (printable PrintableProof) { - if proof.FinalEvalProof != nil { - panic("null expected") - } - printable.FinalEvalProof = struct{}{} - printable.PartialSumPolys = test_vector_utils.ElementSliceSliceToInterfaceSliceSlice(proof.PartialSumPolys) - return -} - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []small_rational.SmallRational{sum} -} - -func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum small_rational.SmallRational -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(small_rational.SmallRational) small_rational.SmallRational { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} diff --git a/internal/generator/sumcheck/test_vectors/vectors.json b/internal/generator/sumcheck/test_vectors/vectors.json deleted file mode 100644 index 64b8e3fb2..000000000 --- a/internal/generator/sumcheck/test_vectors/vectors.json +++ /dev/null @@ -1,56 +0,0 @@ -{ - "linear_univariate_single_claim": { - "type": "multilin", - "hash": { - "type": "const", - "val": -1 - }, - "values": [ - 1, - 3 - ], - "description": "X ↦ 2X + 1", - "proof": { - "partialSumPolys": [ - [ - 3 - ] - ], - "finalEvalProof": {} - }, - "claimedSum": 4 - }, - "trilinear_single_claim": { - "type": "multilin", - "hash": { - "type": "const", - "val": -1 - }, - "values": [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8 - ], - "description": "X₁, X₂, X₃ ↦ 1 + 4X₁ + 2X₂ + X₃", - "proof": { - "partialSumPolys": [ - [ - 26 - ], - [ - -1 - ], - [ - -4 - ] - ], - "finalEvalProof": {} - }, - "claimedSum": 36 - } -} \ No newline at end of file diff --git a/internal/generator/test_vector_utils/generate.go b/internal/generator/test_vector_utils/generate.go deleted file mode 100644 index 5b8bc05e9..000000000 --- a/internal/generator/test_vector_utils/generate.go +++ /dev/null @@ -1,59 +0,0 @@ -package test_vector_utils - -import ( - "path/filepath" - - "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/generator/config" - "github.com/consensys/gnark-crypto/internal/generator/gkr" - "github.com/consensys/gnark-crypto/internal/generator/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/sumcheck" -) - -type Config struct { - config.FieldDependency - RandomizeMissingHashEntries bool -} - -func GenerateRationals(bgen *bavard.BatchGenerator) error { - gkrConf := gkr.Config{ - FieldDependency: config.FieldDependency{ - FieldPackagePath: "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational", - FieldPackageName: "small_rational", - ElementType: "small_rational.SmallRational", - }, - GenerateTests: false, - RetainTestCaseRawInfo: true, - CanUseFFT: false, - TestVectorsRelativePath: "../../../gkr/test_vectors", - } - - baseDir := "./test_vector_utils/small_rational/" - if err := polynomial.Generate(gkrConf.FieldDependency, baseDir+"polynomial", false, bgen); err != nil { - return err - } - if err := sumcheck.Generate(gkrConf.FieldDependency, baseDir+"sumcheck", bgen); err != nil { - return err - } - if err := gkr.Generate(gkrConf, baseDir+"gkr", bgen); err != nil { - return err - } - if err := Generate(Config{gkrConf.FieldDependency, true}, baseDir+"test_vector_utils", bgen); err != nil { - return err - } - - // generate gkr test vector generator for rationals - gkrConf.OutsideGkrPackage = true - return bgen.Generate(gkrConf, "main", "./gkr/template", bavard.Entry{ - File: filepath.Join("gkr", "test_vectors", "main.go"), Templates: []string{"gkr.test.vectors.gen.go.tmpl", "gkr.test.vectors.go.tmpl"}, - }) - -} - -func Generate(conf Config, baseDir string, bgen *bavard.BatchGenerator) error { - entry := bavard.Entry{ - File: filepath.Join(baseDir, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}, - } - - return bgen.Generate(conf, "test_vector_utils", "./test_vector_utils/template/", entry) -} diff --git a/internal/generator/test_vector_utils/small_rational/gkr/gkr.go b/internal/generator/test_vector_utils/small_rational/gkr/gkr.go deleted file mode 100644 index 3158701d3..000000000 --- a/internal/generator/test_vector_utils/small_rational/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational - claimedEvaluations []small_rational.SmallRational - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation small_rational.SmallRational - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational // x in the paper - claimedEvaluations []small_rational.SmallRational // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]small_rational.SmallRational, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step small_rational.SmallRational - - res := make([]small_rational.SmallRational, degGJ) - operands := make([]small_rational.SmallRational, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { - - //defer the proof, return list of claims - evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { - res := make([]small_rational.SmallRational, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []small_rational.SmallRational{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func TopologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := TopologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]small_rational.SmallRational, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]small_rational.SmallRational, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/internal/generator/test_vector_utils/small_rational/gkr/registry.go b/internal/generator/test_vector_utils/small_rational/gkr/registry.go deleted file mode 100644 index 02c78e9bc..000000000 --- a/internal/generator/test_vector_utils/small_rational/gkr/registry.go +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(small_rational.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]small_rational.SmallRational, nbIn) - consts := make(small_rational.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - x := make(small_rational.Vector, degreeBound) - x.MustSetRandom() - for i := range x { - fIn[0] = x[i] - for j := range consts { - fIn[j+1].Mul(&x[i], &consts[j]) - } - p[i] = f(fIn...) - } - - // obtain p's coefficients - p, err := interpolate(x, p) - if err != nil { - panic(err) - } - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) -// Note that the runtime is O(len(X)³) -func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { - if len(X) != len(Y) { - return nil, errors.New("X and Y must have the same length") - } - - // solve the system of equations by Gaussian elimination - augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values - for i := range augmentedRows { - augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) - augmentedRows[i][0].SetOne() - augmentedRows[i][1].Set(&X[i]) - for j := 2; j < len(augmentedRows[i])-1; j++ { - augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) - } - augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) - } - - // make the upper triangle - for i := range len(augmentedRows) - 1 { - // use row i to eliminate the ith element in all rows below - var negInv small_rational.SmallRational - if augmentedRows[i][i].IsZero() { - return nil, errors.New("singular matrix") - } - negInv.Inverse(&augmentedRows[i][i]) - negInv.Neg(&negInv) - for j := i + 1; j < len(augmentedRows); j++ { - var c small_rational.SmallRational - c.Mul(&augmentedRows[j][i], &negInv) - // augmentedRows[j][i].SetZero() omitted - for k := i + 1; k < len(augmentedRows[i]); k++ { - var t small_rational.SmallRational - t.Mul(&augmentedRows[i][k], &c) - augmentedRows[j][k].Add(&augmentedRows[j][k], &t) - } - } - } - - // back substitution - res := make(polynomial.Polynomial, len(X)) - for i := len(augmentedRows) - 1; i >= 0; i-- { - res[i] = augmentedRows[i][len(augmentedRows[i])-1] - for j := i + 1; j < len(augmentedRows[i])-1; j++ { - var t small_rational.SmallRational - t.Mul(&res[j], &augmentedRows[i][j]) - res[i].Sub(&res[i], &t) - } - res[i].Div(&res[i], &augmentedRows[i][i]) - } - - return res, nil -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/internal/generator/test_vector_utils/small_rational/polynomial/doc.go b/internal/generator/test_vector_utils/small_rational/polynomial/doc.go deleted file mode 100644 index ead3b5cba..000000000 --- a/internal/generator/test_vector_utils/small_rational/polynomial/doc.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -// Package polynomial provides polynomial methods and commitment schemes. -package polynomial diff --git a/internal/generator/test_vector_utils/small_rational/polynomial/multilin.go b/internal/generator/test_vector_utils/small_rational/polynomial/multilin.go deleted file mode 100644 index 87b46a7ed..000000000 --- a/internal/generator/test_vector_utils/small_rational/polynomial/multilin.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package polynomial - -import ( - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/utils" - "math/bits" -) - -// MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial -// The variables are X₁ through Xₙ where n = log(len(.)) -// .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) -// It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial -type MultiLin []small_rational.SmallRational - -// Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r -func (m *MultiLin) Fold(r small_rational.SmallRational) { - mid := len(*m) / 2 - - bottom, top := (*m)[:mid], (*m)[mid:] - - var t small_rational.SmallRational // no need to update the top part - - // updating bookkeeping table - // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) - // the following loop computes the evaluations of f(r) accordingly: - // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) - for i := 0; i < mid; i++ { - // table[i] ← table[i] + r (table[i + mid] - table[i]) - t.Sub(&top[i], &bottom[i]) - t.Mul(&t, &r) - bottom[i].Add(&bottom[i], &t) - } - - *m = (*m)[:mid] -} - -func (m *MultiLin) FoldParallel(r small_rational.SmallRational) utils.Task { - mid := len(*m) / 2 - bottom, top := (*m)[:mid], (*m)[mid:] - - *m = bottom - - return func(start, end int) { - var t small_rational.SmallRational // no need to update the top part - for i := start; i < end; i++ { - // table[i] ← table[i] + r (table[i + mid] - table[i]) - t.Sub(&top[i], &bottom[i]) - t.Mul(&t, &r) - bottom[i].Add(&bottom[i], &t) - } - } -} - -func (m MultiLin) Sum() small_rational.SmallRational { - s := m[0] - for i := 1; i < len(m); i++ { - s.Add(&s, &m[i]) - } - return s -} - -func _clone(m MultiLin, p *Pool) MultiLin { - if p == nil { - return m.Clone() - } else { - return p.Clone(m) - } -} - -func _dump(m MultiLin, p *Pool) { - if p != nil { - p.Dump(m) - } -} - -// Evaluate extrapolate the value of the multilinear polynomial corresponding to m -// on the given coordinates -func (m MultiLin) Evaluate(coordinates []small_rational.SmallRational, p *Pool) small_rational.SmallRational { - // Folding is a mutating operation - bkCopy := _clone(m, p) - - // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) - for _, r := range coordinates { - bkCopy.Fold(r) - } - - result := bkCopy[0] - - _dump(bkCopy, p) - return result -} - -// Clone creates a deep copy of a bookkeeping table. -// Both multilinear interpolation and sumcheck require folding an underlying -// array, but folding changes the array. To do both one requires a deep copy -// of the bookkeeping table. -func (m MultiLin) Clone() MultiLin { - res := make(MultiLin, len(m)) - copy(res, m) - return res -} - -// Add two bookKeepingTables -func (m *MultiLin) Add(left, right MultiLin) { - size := len(left) - // Check that left and right have the same size - if len(right) != size || len(*m) != size { - panic("left, right and destination must have the right size") - } - - // Add elementwise - for i := 0; i < size; i++ { - (*m)[i].Add(&left[i], &right[i]) - } -} - -// EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) -// where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| -// -// x -// -// In other words the polynomial evaluated here is the multilinear extrapolation of -// one that evaluates to q' == h' for vectors q', h' of binary values -func EvalEq(q, h []small_rational.SmallRational) small_rational.SmallRational { - var res, nxt, one, sum small_rational.SmallRational - one.SetOne() - for i := 0; i < len(q); i++ { - nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ - nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ - nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ - sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? - - if i == 0 { - res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ - } else { - nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ - res.Mul(&res, &nxt) // res <- res * nxt - } - } - return res -} - -// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] -func (m *MultiLin) Eq(q []small_rational.SmallRational) { - n := len(q) - - if len(*m) != 1<= 0; i-- { - res.Mul(&res, v) - res.Add(&res, &(*p)[i]) - } - - return res -} - -// Clone returns a copy of the polynomial -func (p *Polynomial) Clone() Polynomial { - _p := make(Polynomial, len(*p)) - copy(_p, *p) - return _p -} - -// Set to another polynomial -func (p *Polynomial) Set(p1 Polynomial) { - if len(*p) != len(p1) { - *p = p1.Clone() - return - } - - for i := 0; i < len(p1); i++ { - (*p)[i].Set(&p1[i]) - } -} - -// AddConstantInPlace adds a constant to the polynomial, modifying p -func (p *Polynomial) AddConstantInPlace(c *small_rational.SmallRational) { - for i := 0; i < len(*p); i++ { - (*p)[i].Add(&(*p)[i], c) - } -} - -// SubConstantInPlace subs a constant to the polynomial, modifying p -func (p *Polynomial) SubConstantInPlace(c *small_rational.SmallRational) { - for i := 0; i < len(*p); i++ { - (*p)[i].Sub(&(*p)[i], c) - } -} - -// ScaleInPlace multiplies p by v, modifying p -func (p *Polynomial) ScaleInPlace(c *small_rational.SmallRational) { - for i := 0; i < len(*p); i++ { - (*p)[i].Mul(&(*p)[i], c) - } -} - -// Scale multiplies p0 by v, storing the result in p -func (p *Polynomial) Scale(c *small_rational.SmallRational, p0 Polynomial) { - if len(*p) != len(p0) { - *p = make(Polynomial, len(p0)) - } - for i := 0; i < len(p0); i++ { - (*p)[i].Mul(c, &p0[i]) - } -} - -// Add adds p1 to p2 -// This function allocates a new slice unless p == p1 or p == p2 -func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { - - bigger := p1 - smaller := p2 - if len(bigger) < len(smaller) { - bigger, smaller = smaller, bigger - } - - if len(*p) == len(bigger) && (&(*p)[0] == &bigger[0]) { - for i := 0; i < len(smaller); i++ { - (*p)[i].Add(&(*p)[i], &smaller[i]) - } - return p - } - - if len(*p) == len(smaller) && (&(*p)[0] == &smaller[0]) { - for i := 0; i < len(smaller); i++ { - (*p)[i].Add(&(*p)[i], &bigger[i]) - } - *p = append(*p, bigger[len(smaller):]...) - return p - } - - res := make(Polynomial, len(bigger)) - copy(res, bigger) - for i := 0; i < len(smaller); i++ { - res[i].Add(&res[i], &smaller[i]) - } - *p = res - return p -} - -// Sub subtracts p2 from p1 -// TODO make interface more consistent with Add -func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { - if len(p1) != len(p2) || len(p2) != len(*p) { - return nil - } - for i := 0; i < len(*p); i++ { - (*p)[i].Sub(&p1[i], &p2[i]) - } - return p -} - -// Equal checks equality between two polynomials -func (p *Polynomial) Equal(p1 Polynomial) bool { - if (*p == nil) != (p1 == nil) { - return false - } - - if len(*p) != len(p1) { - return false - } - - for i := range p1 { - if !(*p)[i].Equal(&p1[i]) { - return false - } - } - - return true -} - -func (p Polynomial) SetZero() { - for i := 0; i < len(p); i++ { - p[i].SetZero() - } -} - -func (p Polynomial) Text(base int) string { - - var builder strings.Builder - - first := true - for d := len(p) - 1; d >= 0; d-- { - if p[d].IsZero() { - continue - } - - pD := p[d] - pDText := pD.Text(base) - - initialLen := builder.Len() - - if pDText[0] == '-' { - pDText = pDText[1:] - if first { - builder.WriteString("-") - } else { - builder.WriteString(" - ") - } - } else if !first { - builder.WriteString(" + ") - } - - first = false - - if !pD.IsOne() || d == 0 { - builder.WriteString(pDText) - } - - if builder.Len()-initialLen > 10 { - builder.WriteString("×") - } - - if d != 0 { - builder.WriteString("X") - } - if d > 1 { - builder.WriteString( - utils.ToSuperscript(strconv.Itoa(d)), - ) - } - - } - - if first { - return "0" - } - - return builder.String() -} - -// InterpolateOnRange maps vector v to polynomial f -// such that f(i) = v[i] for 0 ≤ i < len(v). -// len(f) = len(v) and deg(f) ≤ len(v) - 1 -func InterpolateOnRange(v []small_rational.SmallRational) Polynomial { - nEvals := uint8(len(v)) - if int(nEvals) != len(v) { - panic("interpolation method too inefficient for nEvals > 255") - } - lagrange := getLagrangeBasis(nEvals) - - var res Polynomial - res.Scale(&v[0], lagrange[0]) - - temp := make(Polynomial, nEvals) - - for i := uint8(1); i < nEvals; i++ { - temp.Scale(&v[i], lagrange[i]) - res.Add(res, temp) - } - - return res -} - -// lagrange bases used by InterpolateOnRange -var lagrangeBasis sync.Map - -func getLagrangeBasis(domainSize uint8) []Polynomial { - if res, ok := lagrangeBasis.Load(domainSize); ok { - return res.([]Polynomial) - } - - // not found. compute - var res []Polynomial - if domainSize >= 2 { - res = computeLagrangeBasis(domainSize) - } else if domainSize == 1 { - res = []Polynomial{make(Polynomial, 1)} - res[0][0].SetOne() - } - lagrangeBasis.Store(domainSize, res) - - return res -} - -// computeLagrangeBasis precomputes in explicit coefficient form for each 0 ≤ l < domainSize the polynomial -// pₗ := X (X-1) ... (X-l-1) (X-l+1) ... (X - domainSize + 1) / ( l (l-1) ... 2 (-1) ... (l - domainSize +1) ) -// Note that pₗ(l) = 1 and pₗ(n) = 0 if 0 ≤ l < domainSize, n ≠ l -func computeLagrangeBasis(domainSize uint8) []Polynomial { - - constTerms := make([]small_rational.SmallRational, domainSize) - for i := uint8(0); i < domainSize; i++ { - constTerms[i].SetInt64(-int64(i)) - } - - res := make([]Polynomial, domainSize) - multScratch := make(Polynomial, domainSize-1) - - // compute pₗ - for l := uint8(0); l < domainSize; l++ { - - // TODO @Tabaie Optimize this with some trees? O(log(domainSize)) polynomial mults instead of O(domainSize)? Then again it would be fewer big poly mults vs many small poly mults - d := uint8(0) //d is the current degree of res - for i := uint8(0); i < domainSize; i++ { - if i == l { - continue - } - if d == 0 { - res[l] = make(Polynomial, domainSize) - res[l][domainSize-2] = constTerms[i] - res[l][domainSize-1].SetOne() - } else { - current := res[l][domainSize-d-2:] - timesConst := multScratch[domainSize-d-2:] - - timesConst.Scale(&constTerms[i], current[1:]) //TODO: Directly double and add since constTerms are tiny? (even less than 4 bits) - nonLeading := current[0 : d+1] - - nonLeading.Add(nonLeading, timesConst) - - } - d++ - } - - } - - // We have pₗ(i≠l)=0. Now scale so that pₗ(l)=1 - // Replace the constTerms with norms - for l := uint8(0); l < domainSize; l++ { - constTerms[l].Neg(&constTerms[l]) - constTerms[l] = res[l].Eval(&constTerms[l]) - } - constTerms = small_rational.BatchInvert(constTerms) - for l := uint8(0); l < domainSize; l++ { - res[l].ScaleInPlace(&constTerms[l]) - } - - return res -} diff --git a/internal/generator/test_vector_utils/small_rational/polynomial/pool.go b/internal/generator/test_vector_utils/small_rational/polynomial/pool.go deleted file mode 100644 index 688b19390..000000000 --- a/internal/generator/test_vector_utils/small_rational/polynomial/pool.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package polynomial - -import ( - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" -) - -// Do as little as possible to instantiate the interface -type Pool struct { -} - -func NewPool(...int) (pool Pool) { - return Pool{} -} - -func (p *Pool) Make(n int) []small_rational.SmallRational { - return make([]small_rational.SmallRational, n) -} - -func (p *Pool) Dump(...[]small_rational.SmallRational) { -} - -func (p *Pool) Clone(slice []small_rational.SmallRational) []small_rational.SmallRational { - res := p.Make(len(slice)) - copy(res, slice) - return res -} diff --git a/internal/generator/test_vector_utils/small_rational/small-rational.go b/internal/generator/test_vector_utils/small_rational/small-rational.go deleted file mode 100644 index 3015ba7eb..000000000 --- a/internal/generator/test_vector_utils/small_rational/small-rational.go +++ /dev/null @@ -1,452 +0,0 @@ -package small_rational - -import ( - "crypto/rand" - "fmt" - "math/big" - "strconv" - "strings" -) - -const Bytes = 64 - -type SmallRational struct { - text string //For debugging purposes - numerator big.Int - denominator big.Int // By convention, denominator == 0 also indicates zero -} - -var smallPrimes = []*big.Int{ - big.NewInt(2), big.NewInt(3), big.NewInt(5), - big.NewInt(7), big.NewInt(11), big.NewInt(13), -} - -func bigDivides(p, a *big.Int) bool { - var remainder big.Int - remainder.Mod(a, p) - return remainder.BitLen() == 0 -} - -func (z *SmallRational) UpdateText() { - z.text = z.Text(10) -} - -func (z *SmallRational) simplify() { - - if z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 { - return - } - - var num, den big.Int - - num.Set(&z.numerator) - den.Set(&z.denominator) - - for _, p := range smallPrimes { - for bigDivides(p, &num) && bigDivides(p, &den) { - num.Div(&num, p) - den.Div(&den, p) - } - } - - if bigDivides(&den, &num) { - num.Div(&num, &den) - den.SetInt64(1) - } - - z.numerator = num - z.denominator = den - -} -func (z *SmallRational) Square(x *SmallRational) *SmallRational { - var num, den big.Int - num.Mul(&x.numerator, &x.numerator) - den.Mul(&x.denominator, &x.denominator) - - z.numerator = num - z.denominator = den - - z.UpdateText() - - return z -} - -func (z *SmallRational) String() string { - z.text = z.Text(10) - return z.text -} - -func (z *SmallRational) Add(x, y *SmallRational) *SmallRational { - if x.denominator.BitLen() == 0 { - *z = *y - } else if y.denominator.BitLen() == 0 { - *z = *x - } else { - //TODO: Exploit cases where one denom divides the other - var numDen, denNum big.Int - numDen.Mul(&x.numerator, &y.denominator) - denNum.Mul(&x.denominator, &y.numerator) - - numDen.Add(&denNum, &numDen) - z.numerator = numDen //to avoid shallow copy problems - - denNum.Mul(&x.denominator, &y.denominator) - z.denominator = denNum - z.simplify() - } - - z.UpdateText() - - return z -} - -func (z *SmallRational) IsZero() bool { - return z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 -} - -func (z *SmallRational) Inverse(x *SmallRational) *SmallRational { - if x.IsZero() { - *z = *x - } else { - *z = SmallRational{numerator: x.denominator, denominator: x.numerator} - z.UpdateText() - } - - return z -} - -func (z *SmallRational) Neg(x *SmallRational) *SmallRational { - z.numerator.Neg(&x.numerator) - z.denominator = x.denominator - - if x.text == "" { - x.UpdateText() - } - - if x.text[0] == '-' { - z.text = x.text[1:] - } else { - z.text = "-" + x.text - } - - return z -} - -func (z *SmallRational) Double(x *SmallRational) *SmallRational { - - var y big.Int - - if x.denominator.Bit(0) == 0 { - z.numerator = x.numerator - y.Rsh(&x.denominator, 1) - z.denominator = y - } else { - y.Lsh(&x.numerator, 1) - z.numerator = y - z.denominator = x.denominator - } - - z.UpdateText() - - return z -} - -func (z *SmallRational) Sign() int { - return z.numerator.Sign() * z.denominator.Sign() -} - -func (z *SmallRational) MarshalJSON() ([]byte, error) { - return []byte(z.String()), nil -} - -func (z *SmallRational) UnmarshalJson(data []byte) error { - _, err := z.SetInterface(string(data)) - return err -} - -func (z *SmallRational) Equal(x *SmallRational) bool { - return z.Cmp(x) == 0 -} - -func (z *SmallRational) Sub(x, y *SmallRational) *SmallRational { - var yNeg SmallRational - yNeg.Neg(y) - z.Add(x, &yNeg) - - z.UpdateText() - return z -} - -func (z *SmallRational) Cmp(x *SmallRational) int { - zSign, xSign := z.Sign(), x.Sign() - - if zSign > xSign { - return 1 - } - if zSign < xSign { - return -1 - } - - var Z, X big.Int - Z.Mul(&z.numerator, &x.denominator) - X.Mul(&x.numerator, &z.denominator) - - Z.Abs(&Z) - X.Abs(&X) - - return Z.Cmp(&X) * zSign - -} - -func BatchInvert(a []SmallRational) []SmallRational { - res := make([]SmallRational, len(a)) - for i := range a { - res[i].Inverse(&a[i]) - } - return res -} - -func (z *SmallRational) Mul(x, y *SmallRational) *SmallRational { - var num, den big.Int - - num.Mul(&x.numerator, &y.numerator) - den.Mul(&x.denominator, &y.denominator) - - z.numerator = num - z.denominator = den - - z.simplify() - z.UpdateText() - return z -} - -func (z *SmallRational) Div(x, y *SmallRational) *SmallRational { - var num, den big.Int - - num.Mul(&x.numerator, &y.denominator) - den.Mul(&x.denominator, &y.numerator) - - z.numerator = num - z.denominator = den - - z.simplify() - z.UpdateText() - return z -} - -func (z *SmallRational) Halve() *SmallRational { - if z.numerator.Bit(0) == 0 { - z.numerator.Rsh(&z.numerator, 1) - } else { - z.denominator.Lsh(&z.denominator, 1) - } - - z.simplify() - z.UpdateText() - return z -} - -func (z *SmallRational) SetOne() *SmallRational { - return z.SetInt64(1) -} - -func (z *SmallRational) SetZero() *SmallRational { - return z.SetInt64(0) -} - -func (z *SmallRational) SetInt64(i int64) *SmallRational { - z.numerator = *big.NewInt(i) - z.denominator = *big.NewInt(1) - z.text = strconv.FormatInt(i, 10) - return z -} - -func (z *SmallRational) SetRandom() (*SmallRational, error) { - - bytes := make([]byte, 1) - n, err := rand.Read(bytes) - if err != nil { - return nil, err - } - if n != len(bytes) { - return nil, fmt.Errorf("%d bytes read instead of %d", n, len(bytes)) - } - - z.numerator = *big.NewInt(int64(bytes[0]%16) - 8) - z.denominator = *big.NewInt(int64((bytes[0]) / 16)) - - z.simplify() - z.UpdateText() - - return z, nil -} - -func (z *SmallRational) MustSetRandom() *SmallRational { - if _, err := z.SetRandom(); err != nil { - panic(err) - } - return z -} - -func (z *SmallRational) SetUint64(i uint64) { - var num big.Int - num.SetUint64(i) - z.numerator = num - z.denominator = *big.NewInt(1) - z.text = strconv.FormatUint(i, 10) -} - -func (z *SmallRational) IsOne() bool { - return z.numerator.Cmp(&z.denominator) == 0 && z.denominator.BitLen() != 0 -} - -func (z *SmallRational) Text(base int) string { - - if z.denominator.BitLen() == 0 { - return "0" - } - - if z.denominator.Sign() < 0 { - var num, den big.Int - num.Neg(&z.numerator) - den.Neg(&z.denominator) - z.numerator = num - z.denominator = den - } - - if bigDivides(&z.denominator, &z.numerator) { - var num big.Int - num.Div(&z.numerator, &z.denominator) - z.numerator = num - z.denominator = *big.NewInt(1) - } - - numerator := z.numerator.Text(base) - - if z.denominator.IsInt64() && z.denominator.Int64() == 1 { - return numerator - } - - return numerator + "/" + z.denominator.Text(base) -} - -func (z *SmallRational) Set(x *SmallRational) *SmallRational { - *z = *x // shallow copy is safe because ops are never in place - return z -} - -func (z *SmallRational) SetInterface(x interface{}) (*SmallRational, error) { - - switch v := x.(type) { - case *SmallRational: - *z = *v - case SmallRational: - *z = v - case int64: - z.SetInt64(v) - case int: - z.SetInt64(int64(v)) - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - case string: - z.text = v - sep := strings.Split(v, "/") - switch len(sep) { - case 1: - if asInt, err := strconv.Atoi(sep[0]); err == nil { - z.SetInt64(int64(asInt)) - } else { - return nil, err - } - case 2: - var err error - var num, denom int - num, err = strconv.Atoi(sep[0]) - if err != nil { - return nil, err - } - denom, err = strconv.Atoi(sep[1]) - if err != nil { - return nil, err - } - z.numerator = *big.NewInt(int64(num)) - z.denominator = *big.NewInt(int64(denom)) - default: - return nil, fmt.Errorf("cannot parse \"%s\"", v) - } - default: - return nil, fmt.Errorf("cannot parse %T", x) - } - - return z, nil -} - -func bigIntToBytesSigned(dst []byte, src big.Int) { - src.FillBytes(dst[1:]) - dst[0] = 0 - if src.Sign() < 0 { - dst[0] = 255 - } -} - -func (z *SmallRational) Bytes() [Bytes]byte { - var res [Bytes]byte - bigIntToBytesSigned(res[:Bytes/2], z.numerator) - bigIntToBytesSigned(res[Bytes/2:], z.denominator) - return res -} - -func bytesToBigIntSigned(src []byte) big.Int { - var res big.Int - res.SetBytes(src[1:]) - if src[0] != 0 { - res.Neg(&res) - } - return res -} - -// BigInt returns sets dst to the value of z if it is an integer. -// if z is not an integer, nil is returned. -// if the given dst is nil, the address of the numerator is returned. -// if the given dst is non-nil, it is returned. -func (z *SmallRational) BigInt(dst *big.Int) *big.Int { - if z.denominator.Cmp(big.NewInt(1)) != 0 { - return nil - } - if dst == nil { - return &z.numerator - } - dst.Set(&z.numerator) - return dst -} - -func (z *SmallRational) SetBytes(b []byte) { - if len(b) > Bytes/2 { - z.numerator = bytesToBigIntSigned(b[:Bytes/2]) - z.denominator = bytesToBigIntSigned(b[Bytes/2:]) - } else { - z.numerator.SetBytes(b) - z.denominator.SetInt64(1) - } - z.simplify() - z.UpdateText() -} - -func One() SmallRational { - res := SmallRational{ - text: "1", - } - res.numerator.SetInt64(1) - res.denominator.SetInt64(1) - return res -} - -func Modulus() *big.Int { - res := big.NewInt(1) - res.Lsh(res, 64) - return res -} diff --git a/internal/generator/test_vector_utils/small_rational/small_rational_test.go b/internal/generator/test_vector_utils/small_rational/small_rational_test.go deleted file mode 100644 index 6d7733ea7..000000000 --- a/internal/generator/test_vector_utils/small_rational/small_rational_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package small_rational - -import ( - "github.com/stretchr/testify/assert" - "math/big" - "testing" -) - -func TestBigDivides(t *testing.T) { - assert.True(t, bigDivides(big.NewInt(-1), big.NewInt(4))) - assert.False(t, bigDivides(big.NewInt(-3), big.NewInt(4))) -} - -func TestCmp(t *testing.T) { - - cases := make([]SmallRational, 36) - - for i := int64(0); i < 9; i++ { - if i%2 == 0 { - cases[4*i].numerator.SetInt64((i - 4) / 2) - cases[4*i].denominator.SetInt64(1) - } else { - cases[4*i].numerator.SetInt64(i - 4) - cases[4*i].denominator.SetInt64(2) - } - - cases[4*i+1].numerator.Neg(&cases[4*i].numerator) - cases[4*i+1].denominator.Neg(&cases[4*i].denominator) - - cases[4*i+2].numerator.Lsh(&cases[4*i].numerator, 1) - cases[4*i+2].denominator.Lsh(&cases[4*i].denominator, 1) - - cases[4*i+3].numerator.Neg(&cases[4*i+2].numerator) - cases[4*i+3].denominator.Neg(&cases[4*i+2].denominator) - } - - for i := range cases { - for j := range cases { - I, J := i/4, j/4 - var expectedCmp int - cmp := cases[i].Cmp(&cases[j]) - if I < J { - expectedCmp = -1 - } else if I == J { - expectedCmp = 0 - } else { - expectedCmp = 1 - } - assert.Equal(t, expectedCmp, cmp, "comparing index %d, index %d", i, j) - } - } - - zeroIndex := len(cases) / 8 - var weirdZero SmallRational - for i := range cases { - I := i / 4 - var expectedCmp int - cmp := cases[i].Cmp(&weirdZero) - cmpNeg := weirdZero.Cmp(&cases[i]) - if I < zeroIndex { - expectedCmp = -1 - } else if I == zeroIndex { - expectedCmp = 0 - } else { - expectedCmp = 1 - } - - assert.Equal(t, expectedCmp, cmp, "comparing index %d, 0/0", i) - assert.Equal(t, -expectedCmp, cmpNeg, "comparing 0/0, index %d", i) - } -} - -func TestDouble(t *testing.T) { - values := []interface{}{1, 2, 3, 4, 5, "2/3", "3/2", "-3/-2"} - valsDoubled := []interface{}{2, 4, 6, 8, 10, "-4/-3", 3, 3} - - for i := range values { - var v, vDoubled, vDoubledExpected SmallRational - _, err := v.SetInterface(values[i]) - assert.NoError(t, err) - _, err = vDoubledExpected.SetInterface(valsDoubled[i]) - assert.NoError(t, err) - vDoubled.Double(&v) - assert.True(t, vDoubled.Equal(&vDoubledExpected), - "mismatch at %d: expected 2×%s = %s, saw %s", i, v.text, vDoubledExpected.text, vDoubled.text) - - } -} - -func TestOperandConstancy(t *testing.T) { - var p0, p, pPure SmallRational - p0.SetInt64(1) - p.SetInt64(-3) - pPure.SetInt64(-3) - - res := p - res.Add(&res, &p0) - assert.True(t, p.Equal(&pPure)) -} - -func TestSquare(t *testing.T) { - var two, four, x SmallRational - two.SetInt64(2) - four.SetInt64(4) - - x.Square(&two) - - assert.True(t, x.Equal(&four), "expected 4, saw %s", x.Text(10)) -} - -func TestSetBytes(t *testing.T) { - var c SmallRational - c.SetBytes([]byte("firstChallenge.0")) - -} diff --git a/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck.go b/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck.go deleted file mode 100644 index eb8e62c99..000000000 --- a/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "errors" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return small_rational.SmallRational{}, err - } - } - var res small_rational.SmallRational - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff small_rational.SmallRational - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]small_rational.SmallRational, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff small_rational.SmallRational - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]small_rational.SmallRational, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck_test.go b/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck_test.go deleted file mode 100644 index ad6f8ac23..000000000 --- a/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package sumcheck - -import ( - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []small_rational.SmallRational{sum} -} - -func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum small_rational.SmallRational -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils.go b/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils.go deleted file mode 100644 index 9e91fe7c6..000000000 --- a/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils.go +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package test_vector_utils - -import ( - "fmt" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "hash" - "reflect" -) - -func ToElement(i int64) *small_rational.SmallRational { - var res small_rational.SmallRational - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter{startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res small_rational.SmallRational - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return small_rational.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return small_rational.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []small_rational.SmallRational - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return small_rational.Bytes -} - -func (h *ListHash) BlockSize() int { - return small_rational.Bytes -} - -func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { - elementSlice := make([]small_rational.SmallRational, len(slice)) - for i, v := range slice { - if _, err := elementSlice[i].SetInterface(v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i], b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *small_rational.SmallRational) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().(small_rational.SmallRational) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/internal/generator/test_vector_utils/small_rational/vector.go b/internal/generator/test_vector_utils/small_rational/vector.go deleted file mode 100644 index 07fcc3aff..000000000 --- a/internal/generator/test_vector_utils/small_rational/vector.go +++ /dev/null @@ -1,9 +0,0 @@ -package small_rational - -type Vector []SmallRational - -func (v Vector) MustSetRandom() { - for i := range v { - v[i].MustSetRandom() - } -} diff --git a/internal/generator/test_vector_utils/template/test_vector_utils.go.tmpl b/internal/generator/test_vector_utils/template/test_vector_utils.go.tmpl deleted file mode 100644 index 5b7495eec..000000000 --- a/internal/generator/test_vector_utils/template/test_vector_utils.go.tmpl +++ /dev/null @@ -1,220 +0,0 @@ -import ( - "fmt" - "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/polynomial" - "hash" - "reflect" - {{if eq .ElementType "fr.Element"}}"strings"{{- end}} -) - -func ToElement(i int64) *{{.ElementType}} { - var res {{.ElementType}} - res.SetInt64(i) - return &res -} - -type HashDescription map[string]interface{} - -func HashFromDescription(d HashDescription) (hash.Hash, error) { - if _type, ok := d["type"]; ok { - switch _type { - case "const": - startState := int64(d["val"].(float64)) - return &MessageCounter {startState: startState, step: 0, state: startState}, nil - default: - return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) - } - } - return nil, fmt.Errorf("hash description missing type") -} - -type MessageCounter struct { - startState int64 - state int64 - step int64 -} - -func (m *MessageCounter) Write(p []byte) (n int, err error) { - inputBlockSize := (len(p)-1)/{{.FieldPackageName}}.Bytes + 1 - m.state += int64(inputBlockSize) * m.step - return len(p), nil -} - -func (m *MessageCounter) Sum(b []byte) []byte { - inputBlockSize := (len(b)-1)/{{.FieldPackageName}}.Bytes + 1 - resI := m.state + int64(inputBlockSize)*m.step - var res {{.ElementType}} - res.SetInt64(int64(resI)) - resBytes := res.Bytes() - return resBytes[:] -} - -func (m *MessageCounter) Reset() { - m.state = m.startState -} - -func (m *MessageCounter) Size() int { - return {{.FieldPackageName}}.Bytes -} - -func (m *MessageCounter) BlockSize() int { - return {{.FieldPackageName}}.Bytes -} - -func NewMessageCounter(startState, step int) hash.Hash { - transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} - return transcript -} - -func NewMessageCounterGenerator(startState, step int) func() hash.Hash { - return func() hash.Hash { - return NewMessageCounter(startState, step) - } -} - -type ListHash []{{.ElementType}} - -func (h *ListHash) Write(p []byte) (n int, err error) { - return len(p), nil -} - -func (h *ListHash) Sum(b []byte) []byte { - res := (*h)[0].Bytes() - *h = (*h)[1:] - return res[:] -} - -func (h *ListHash) Reset() { -} - -func (h *ListHash) Size() int { - return {{.FieldPackageName}}.Bytes -} - -func (h *ListHash) BlockSize() int { -return {{.FieldPackageName}}.Bytes -} - -{{- if eq .ElementType "fr.Element"}} -func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { - - // TODO: Put this in element.SetString? - switch v := value.(type) { - case string: - - if sep := strings.Split(v, "/"); len(sep) == 2 { - var denom fr.Element - if _, err := z.SetString(sep[0]); err != nil { - return nil, err - } - if _, err := denom.SetString(sep[1]); err != nil { - return nil, err - } - denom.Inverse(&denom) - z.Mul(z, &denom) - return z, nil - } - - case float64: - asInt := int64(v) - if float64(asInt) != v { - return nil, fmt.Errorf("cannot currently parse float") - } - z.SetInt64(asInt) - return z, nil - } - - return z.SetInterface(value) -} -{{- end}} - -{{- define "setElement element value elementType"}} -{{- if eq .elementType "fr.Element"}} SetElement(&{{.element}}, {{.value}}) -{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) -{{- else}} - {{print "\"UNEXPECTED TYPE" .elementType "\""}} -{{- end}} -{{- end}} - -func SliceToElementSlice[T any](slice []T) ([]{{.ElementType}}, error) { - elementSlice := make([]{{.ElementType}}, len(slice)) - for i, v := range slice { - if _, err := {{setElement "elementSlice[i]" "v" .ElementType}}; err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []{{.ElementType}}, b []{{.ElementType}}) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -func SliceSliceEquals(a [][]{{.ElementType}}, b [][]{{.ElementType}}) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i],b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if err := SliceEquals(a[i],b[i]); err != nil { - return fmt.Errorf("at index %d: %w", i, err) - } - } - return nil -} - -func ElementToInterface(x *{{.ElementType}}) interface{} { - if i := x.BigInt(nil); i != nil { - return i - } - return x.Text(10) -} - -func ElementSliceToInterfaceSlice(x interface{}) []interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([]interface{}, X.Len()) - for i := range res { - xI := X.Index(i).Interface().({{.ElementType}}) - res[i] = ElementToInterface(&xI) - } - return res -} - -func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { - if x == nil { - return nil - } - - X := reflect.ValueOf(x) - - res := make([][]interface{}, X.Len()) - for i := range res { - res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) - } - - return res -} diff --git a/internal/generator/test_vector_utils/utils.go b/internal/generator/test_vector_utils/utils.go deleted file mode 100644 index 47b05c646..000000000 --- a/internal/generator/test_vector_utils/utils.go +++ /dev/null @@ -1,248 +0,0 @@ -package test_vector_utils - -/* -var hashCache = make(map[string]HashMap) - -func GetHash(path string) (HashMap, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if h, ok := hashCache[path]; ok { - return h, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var asMap map[string]interface{} - if err = json.Unmarshal(bytes, &asMap); err != nil { - return nil, err - } - - res := make(HashMap, 0, len(asMap)) - - for k, v := range asMap { - var entry RationalTriplet - if _, err = entry.value.SetInterface(v); err != nil { - return nil, err - } - - key := strings.Split(k, ",") - - switch len(key) { - case 1: - entry.key2Present = false - case 2: - entry.key2Present = true - if _, err = entry.key2.SetInterface(key[1]); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) - } - if _, err = entry.key1.SetInterface(key[0]); err != nil { - return nil, err - } - - res = append(res, &entry) - } - - res.sort() - - hashCache[path] = res - - return res, nil - - } else { - return nil, err - } -} - -func (m *HashMap) SaveUsedEntries(path string) error { - - var sb strings.Builder - sb.WriteRune('[') - - first := true - - for _, element := range *m { - if !element.used { - continue - } - if !first { - sb.WriteRune(',') - } - first = false - sb.WriteString("\n\t") - element.WriteKeyValue(&sb) - } - - if !first { - sb.WriteRune(',') - } - - sb.WriteString("\n]") - - return os.WriteFile(path, []byte(sb.String()), 0) -} - -type HashMap []*RationalTriplet - -type RationalTriplet struct { - key1 small_rational.SmallRational - key2 small_rational.SmallRational - key2Present bool - value small_rational.SmallRational - used bool -} - -func (t *RationalTriplet) WriteKeyValue(sb *strings.Builder) { - sb.WriteString("\t\"") - sb.WriteString(t.key1.String()) - if t.key2Present { - sb.WriteRune(',') - sb.WriteString(t.key2.String()) - } - sb.WriteString("\":") - if valueBytes, err := json.Marshal(ElementToInterface(&t.value)); err == nil { - sb.WriteString(string(valueBytes)) - } else { - panic(err.Error()) - } -} - -func (m *HashMap) sort() { - sort.Slice(*m, func(i, j int) bool { - return (*m)[i].CmpKey((*m)[j]) <= 0 - }) -} - -func (m *HashMap) find(toFind *RationalTriplet) small_rational.SmallRational { - i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) - - if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { - (*m)[i].used = true - return (*m)[i].value - } - - // if not found, add it: - if _, err := toFind.value.SetInterface(rand.Int63n(11) - 5); err != nil { - panic(err.Error()) - } - toFind.used = true - *m = append(*m, toFind) - m.sort() //Inefficient, but it's okay. This is only run when a new test case is introduced - - return toFind.value -} - -func (t *RationalTriplet) CmpKey(o *RationalTriplet) int { - if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { - return cmp1 - } - - if t.key2Present { - if o.key2Present { - return t.key2.Cmp(&o.key2) - } - return 1 - } else { - if o.key2Present { - return -1 - } - return 0 - } -} - -type MapHashTranscript struct { - HashMap HashMap - stateValid bool - resultAvailable bool - state small_rational.SmallRational -} - -func (m *HashMap) Hash(x *small_rational.SmallRational, y *small_rational.SmallRational) small_rational.SmallRational { - - toFind := RationalTriplet{ - key1: *x, - key2Present: y != nil, - } - - if y != nil { - toFind.key2 = *y - } - - return m.find(&toFind) -} - -func (m *MapHashTranscript) Update(i ...interface{}) { - if len(i) > 0 { - for _, x := range i { - - var xElement small_rational.SmallRational - if _, err := xElement.SetInterface(x); err != nil { - panic(err.Error()) - } - if m.stateValid { - m.state = m.HashMap.Hash(&xElement, &m.state) - } else { - m.state = m.HashMap.Hash(&xElement, nil) - } - - m.stateValid = true - } - } else { //just hash the state itself - if !m.stateValid { - panic("nothing to hash") - } - m.state = m.HashMap.Hash(&m.state, nil) - } - m.resultAvailable = true -} - -func (m *MapHashTranscript) Next(i ...interface{}) small_rational.SmallRational { - - if len(i) > 0 || !m.resultAvailable { - m.Update(i...) - } - m.resultAvailable = false - return m.state -} - -func (m *MapHashTranscript) NextN(N int, i ...interface{}) []small_rational.SmallRational { - - if len(i) > 0 { - m.Update(i...) - } - - res := make([]small_rational.SmallRational, N) - - for n := range res { - res[n] = m.Next() - } - - return res -} - -func SliceToElementSlice(slice []interface{}) ([]small_rational.SmallRational, error) { - elementSlice := make([]small_rational.SmallRational, len(slice)) - for i, v := range slice { - if _, err := elementSlice[i].SetInterface(v); err != nil { - return nil, err - } - } - return elementSlice, nil -} - -func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { - if len(a) != len(b) { - return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) - } - for i := range a { - if !a[i].Equal(&b[i]) { - return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) - } - } - return nil -} - -*/