Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/generator/backend/template/gkr/gate_testing.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

// IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f
func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool {
var api gateAPI
fWrapped := api.convertFunc(f)

// fix all variables except the i-th one at random points
Expand Down Expand Up @@ -130,6 +131,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom
// FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails.
// Failure could be due to the degree being higher than max or the function not being a polynomial at all.
func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) {
var api gateAPI
fFr := api.convertFunc(f)
bound := uint64(max) + 1
for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 {
Expand All @@ -139,11 +141,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) {
}
return len(p) - 1, nil
}
api.freeElements() // not strictly necessary as few iterations are expected.
}
return -1, fmt.Errorf("could not find a degree: tried up to %d", max)
}

func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error {
var api gateAPI
fFr := api.convertFunc(f)
if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil {
return fmt.Errorf("detected a higher degree than %d", claimedDegree)
Expand All @@ -157,6 +161,7 @@ func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error

// EqualGateFunction checks if two gate functions are equal, by testing the same at a random point.
func EqualGateFunction(f gkr.GateFunction, g gkr.GateFunction, nbIn int) bool {
var api gateAPI
x := make({{.FieldPackageName}}.Vector, nbIn)
x.MustSetRandom()
fFr := api.convertFunc(f)
Expand Down
96 changes: 58 additions & 38 deletions internal/generator/backend/template/gkr/gkr.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []{{ .ElementType
for i, uniqueI := range injectionLeftInv { // map from all to unique
inputEvaluations[i] = &uniqueInputEvaluations[uniqueI]
}

gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*{{ .ElementType }}))
var api gateAPI
gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*{{ .ElementType }}))
}

evaluation.Mul(&evaluation, &gateEvaluation)
Expand Down Expand Up @@ -230,7 +230,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {
gJ := make([]{{ .ElementType }}, degGJ)
var mu sync.Mutex
computeAll := func(start, end int) { // compute method to allow parallelization across instances
var step {{ .ElementType }}
var (
step {{ .ElementType }}
api gateAPI
)

res := make([]{{ .ElementType }}, degGJ)

Expand Down Expand Up @@ -260,10 +263,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {
for i := range gateInput {
gateInput[i] = &mlEvals[eIndex+1+i]
}
summand := wire.Gate.Evaluate(api, gateInput...).(*{{ .ElementType }})
summand := wire.Gate.Evaluate(&api, gateInput...).(*{{ .ElementType }})
summand.Mul(summand, &mlEvals[eIndex])
res[d].Add(&res[d], summand) // collect contributions into the sum from start to end
eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml)
api.freeElements()
}
}
mu.Lock()
Expand Down Expand Up @@ -663,6 +667,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment {
}
}

var api gateAPI
ins := make([]{{ .ElementType }}, maxNbIns)
for i := range nbInstances {
for wI, w := range wires {
Expand Down Expand Up @@ -720,52 +725,54 @@ func frToBigInts(dst []*big.Int, src []{{ .ElementType }}) {


// gateAPI implements gkr.GateAPI.
type gateAPI struct{}

var api gateAPI
// It uses a synchronous memory pool underneath to minimize heap allocations.
type gateAPI struct {
allocated []*{{ .ElementType }}
nbUsed int
}

func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
var res {{ .ElementType }} // TODO Heap allocated. Keep an eye on perf
res.Add(cast(i1), cast(i2))
func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
res := api.newElement()
res.Add(api.cast(i1), api.cast(i2))
for _, v := range in {
res.Add(&res, cast(v))
res.Add(res, api.cast(v))
}
return &res
return res
}

func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable {
var prod {{ .ElementType }}
prod.Mul(cast(b), cast(c))
res := cast(a)
res.Add(res, &prod)
return &res
func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable {
prod := api.newElement()
prod.Mul(api.cast(b), api.cast(c))
res := api.cast(a)
res.Add(res, prod)
return res
}

func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable {
var res {{ .ElementType }}
res.Neg(cast(i1))
return &res
func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable {
res := api.newElement()
res.Neg(api.cast(i1))
return res
}

func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
var res {{ .ElementType }}
res.Sub(cast(i1), cast(i2))
func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
res := api.newElement()
res.Sub(api.cast(i1), api.cast(i2))
for _, v := range in {
res.Sub(&res, cast(v))
res.Sub(res, api.cast(v))
}
return &res
return res
}

func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
var res {{ .ElementType }}
res.Mul(cast(i1), cast(i2))
func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
res := api.newElement()
res.Mul(api.cast(i1), api.cast(i2))
for _, v := range in {
res.Mul(&res, cast(v))
res.Mul(res, api.cast(v))
}
return &res
return res
}

func (gateAPI) Println(a ...frontend.Variable) {
func (api *gateAPI) Println(a ...frontend.Variable) {
toPrint := make([]any, len(a))
var x {{ .ElementType }}

Expand All @@ -783,30 +790,43 @@ func (gateAPI) Println(a ...frontend.Variable) {
fmt.Println(toPrint...)
}

func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} {
func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} {
inVar := make([]frontend.Variable, len(in))
for i := range in {
inVar[i] = &in[i]
}
return f(api, inVar...).(*{{ .ElementType }})
}

// Put all elements back in the pool.
func (api *gateAPI) freeElements() {
api.nbUsed = 0
}

func (api *gateAPI) newElement() *{{ .ElementType }} {
api.nbUsed++
if api.nbUsed >= len(api.allocated) {
api.allocated = append(api.allocated, new({{ .ElementType }}))
}
return api.allocated[api.nbUsed-1]
}

type gateFunctionFr func(...{{ .ElementType }}) *{{ .ElementType }}

// convertFunc turns f into a function that accepts and returns {{ .ElementType }}.
func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr {
func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr {
return func(in ...{{ .ElementType }}) *{{ .ElementType }} {
return api.evaluate(f, in...)
}
}

func cast(v frontend.Variable) *{{ .ElementType }} {
func (api *gateAPI) cast(v frontend.Variable) *{{ .ElementType }} {
if x, ok := v.(*{{ .ElementType }}); ok { // fast path, no extra heap allocation
return x
}
var x {{ .ElementType }}
x := api.newElement()
if _, err := x.SetInterface(v); err != nil {
panic(err)
}
return &x
return x
}
6 changes: 4 additions & 2 deletions internal/generator/backend/template/gkr/solver_hints.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func SolveHint(data *SolvingData) hint.Hint {

gateIns := make([]frontend.Variable, data.maxNbIn)
outsI := 0
insI := 1 // skip the first input, which is the instance index
insI := 1 // skip the first input, which is the instance index
var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations.
for wI := range data.circuit {
w := &data.circuit[wI]
if w.IsInput() { // read from provided input
Expand All @@ -110,7 +111,8 @@ func SolveHint(data *SolvingData) hint.Hint {
gateIns[i] = &data.assignment[inWI][instanceI]
}

data.assignment[wI][instanceI].Set(w.Gate.Evaluate(api, gateIns[:len(w.Inputs)]...).(*fr.Element))
data.assignment[wI][instanceI].Set(w.Gate.Evaluate(&api, gateIns[:len(w.Inputs)]...).(*fr.Element))
api.freeElements()
}
if w.IsOutput() {
data.assignment[wI][instanceI].BigInt(outs[outsI])
Expand Down
5 changes: 5 additions & 0 deletions internal/gkr/bls12-377/gate_testing.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading