diff --git a/internal/generator/backend/template/gkr/gate_testing.go.tmpl b/internal/generator/backend/template/gkr/gate_testing.go.tmpl index 8c78af347c..e6bfc24cff 100644 --- a/internal/generator/backend/template/gkr/gate_testing.go.tmpl +++ b/internal/generator/backend/template/gkr/gate_testing.go.tmpl @@ -15,6 +15,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -130,6 +131,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -139,11 +141,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 5105b0a33d..ec78c1c27b 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -98,7 +98,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []{{ .ElementType inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*{{ .ElementType }})) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*{{ .ElementType }})) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -230,7 +231,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]{{ .ElementType }}, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step {{ .ElementType }} + var ( + step {{ .ElementType }} + api gateAPI + ) res := make([]{{ .ElementType }}, degGJ) @@ -260,10 +264,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*{{ .ElementType }}) + summand := wire.Gate.Evaluate(&api, gateInput...).(*{{ .ElementType }}) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -663,6 +668,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]{{ .ElementType }}, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -720,70 +726,72 @@ func frToBigInts(dst []*big.Int, src []{{ .ElementType }}) { // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*{{ .ElementType }} + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod {{ .ElementType }} - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res {{ .ElementType }} - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x {{ .ElementType }} for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .ElementType }} { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -791,22 +799,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...{{ .ElementType }}) *{{ .E return f(api, inVar...).(*{{ .ElementType }}) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *{{ .ElementType }} { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new({{ .ElementType }})) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...{{ .ElementType }}) *{{ .ElementType }} // convertFunc turns f into a function that accepts and returns {{ .ElementType }}. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...{{ .ElementType }}) *{{ .ElementType }} { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *{{ .ElementType }} { +func (api *gateAPI) cast(v frontend.Variable) *{{ .ElementType }} { if x, ok := v.(*{{ .ElementType }}); ok { // fast path, no extra heap allocation return x } - var x {{ .ElementType }} + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/solver_hints.go.tmpl b/internal/generator/backend/template/gkr/solver_hints.go.tmpl index e1d41e8cb8..fbdec861da 100644 --- a/internal/generator/backend/template/gkr/solver_hints.go.tmpl +++ b/internal/generator/backend/template/gkr/solver_hints.go.tmpl @@ -68,6 +68,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -86,7 +87,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bls12-377/gate_testing.go b/internal/gkr/bls12-377/gate_testing.go index 415a5ff5b3..71acc90b85 100644 --- a/internal/gkr/bls12-377/gate_testing.go +++ b/internal/gkr/bls12-377/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index f5dfad020e..5810251288 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-377/solver_hints.go b/internal/gkr/bls12-377/solver_hints.go index 39547cff29..cd105b361f 100644 --- a/internal/gkr/bls12-377/solver_hints.go +++ b/internal/gkr/bls12-377/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bls12-381/gate_testing.go b/internal/gkr/bls12-381/gate_testing.go index ef7694dc18..e96e9999f5 100644 --- a/internal/gkr/bls12-381/gate_testing.go +++ b/internal/gkr/bls12-381/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index f5617a59d4..cd5b427224 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls12-381/solver_hints.go b/internal/gkr/bls12-381/solver_hints.go index cb498c78b7..6027efda53 100644 --- a/internal/gkr/bls12-381/solver_hints.go +++ b/internal/gkr/bls12-381/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bls24-315/gate_testing.go b/internal/gkr/bls24-315/gate_testing.go index 1682d24771..0d0cd5e1e4 100644 --- a/internal/gkr/bls24-315/gate_testing.go +++ b/internal/gkr/bls24-315/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7d89baf7ef..70015a9055 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-315/solver_hints.go b/internal/gkr/bls24-315/solver_hints.go index 914c8a9d61..bd709ba749 100644 --- a/internal/gkr/bls24-315/solver_hints.go +++ b/internal/gkr/bls24-315/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bls24-317/gate_testing.go b/internal/gkr/bls24-317/gate_testing.go index 1bffab29e3..0c91c766c7 100644 --- a/internal/gkr/bls24-317/gate_testing.go +++ b/internal/gkr/bls24-317/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index fc9908b918..9dde9d38ae 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bls24-317/solver_hints.go b/internal/gkr/bls24-317/solver_hints.go index f6e1ad993d..4e17f93abb 100644 --- a/internal/gkr/bls24-317/solver_hints.go +++ b/internal/gkr/bls24-317/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bn254/gate_testing.go b/internal/gkr/bn254/gate_testing.go index 716ba3891b..299244e375 100644 --- a/internal/gkr/bn254/gate_testing.go +++ b/internal/gkr/bn254/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 04cf3512af..a509ace9fc 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bn254/solver_hints.go b/internal/gkr/bn254/solver_hints.go index 7bc3782932..28f955628e 100644 --- a/internal/gkr/bn254/solver_hints.go +++ b/internal/gkr/bn254/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bw6-633/gate_testing.go b/internal/gkr/bw6-633/gate_testing.go index 0fafa45a0d..a502e8f5b5 100644 --- a/internal/gkr/bw6-633/gate_testing.go +++ b/internal/gkr/bw6-633/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index cc1245e726..af4b468751 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-633/solver_hints.go b/internal/gkr/bw6-633/solver_hints.go index 57343d291f..93f91aca92 100644 --- a/internal/gkr/bw6-633/solver_hints.go +++ b/internal/gkr/bw6-633/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/bw6-761/gate_testing.go b/internal/gkr/bw6-761/gate_testing.go index 6eda2ebe73..ab95fd70d3 100644 --- a/internal/gkr/bw6-761/gate_testing.go +++ b/internal/gkr/bw6-761/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -118,6 +119,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -127,11 +129,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index f90f28114b..2d7c4205c5 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []fr.Element, comb inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*fr.Element)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*fr.Element)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]fr.Element, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step fr.Element + var ( + step fr.Element + api gateAPI + ) res := make([]fr.Element, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*fr.Element) + summand := wire.Gate.Evaluate(&api, gateInput...).(*fr.Element) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]fr.Element, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []fr.Element) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*fr.Element + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod fr.Element - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x fr.Element for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...fr.Element) *fr.Element { return f(api, inVar...).(*fr.Element) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *fr.Element { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(fr.Element)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...fr.Element) *fr.Element // convertFunc turns f into a function that accepts and returns fr.Element. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...fr.Element) *fr.Element { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *fr.Element { +func (api *gateAPI) cast(v frontend.Variable) *fr.Element { if x, ok := v.(*fr.Element); ok { // fast path, no extra heap allocation return x } - var x fr.Element + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x } diff --git a/internal/gkr/bw6-761/solver_hints.go b/internal/gkr/bw6-761/solver_hints.go index 606f13ec23..1988b242ec 100644 --- a/internal/gkr/bw6-761/solver_hints.go +++ b/internal/gkr/bw6-761/solver_hints.go @@ -75,6 +75,7 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { } for instanceI := start; instanceI < end; instanceI++ { + var api gateAPI // since the api is synchronous, we can't share it across Solve Hint invocations. for wireI := range data.circuit { wire := &data.circuit[wireI] deps := info.Dependencies[wireI] @@ -93,7 +94,8 @@ func SolveHint(info gkrtypes.SolvingInfo, data *SolvingData) hint.Hint { inputs[i] = &data.assignment[inputI][instanceI] } gate := data.circuit[wireI].Gate - data.assignment[wireI][instanceI].Set(gate.Evaluate(api, inputs[:len(inputIndexes)]...).(*fr.Element)) + data.assignment[wireI][instanceI].Set(gate.Evaluate(&api, inputs[:len(inputIndexes)]...).(*fr.Element)) + api.freeElements() } } } diff --git a/internal/gkr/small_rational/gate_testing.go b/internal/gkr/small_rational/gate_testing.go index 1817cfbf6f..f5a897c2ac 100644 --- a/internal/gkr/small_rational/gate_testing.go +++ b/internal/gkr/small_rational/gate_testing.go @@ -21,6 +21,7 @@ import ( // IsGateFunctionAdditive returns whether x_i occurs only in a monomial of total degree 1 in f func IsGateFunctionAdditive(f gkr.GateFunction, i, nbIn int) bool { + var api gateAPI fWrapped := api.convertFunc(f) // fix all variables except the i-th one at random points @@ -117,6 +118,7 @@ func (f gateFunctionFr) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynom // FindGateFunctionDegree returns the degree of the gate function, or -1 if it fails. // Failure could be due to the degree being higher than max or the function not being a polynomial at all. func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { + var api gateAPI fFr := api.convertFunc(f) bound := uint64(max) + 1 for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { @@ -126,11 +128,13 @@ func FindGateFunctionDegree(f gkr.GateFunction, max, nbIn int) (int, error) { } return len(p) - 1, nil } + api.freeElements() // not strictly necessary as few iterations are expected. } return -1, fmt.Errorf("could not find a degree: tried up to %d", max) } func VerifyGateFunctionDegree(f gkr.GateFunction, claimedDegree, nbIn int) error { + var api gateAPI fFr := api.convertFunc(f) if p := fFr.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { return fmt.Errorf("detected a higher degree than %d", claimedDegree) diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index e8e78f4b96..8bba4053e2 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -105,7 +105,8 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) verifyFinalEval(r []small_rational.S inputEvaluations[i] = &uniqueInputEvaluations[uniqueI] } - gateEvaluation.Set(wire.Gate.Evaluate(api, inputEvaluations...).(*small_rational.SmallRational)) + var api gateAPI + gateEvaluation.Set(wire.Gate.Evaluate(&api, inputEvaluations...).(*small_rational.SmallRational)) } evaluation.Mul(&evaluation, &gateEvaluation) @@ -236,7 +237,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { gJ := make([]small_rational.SmallRational, degGJ) var mu sync.Mutex computeAll := func(start, end int) { // compute method to allow parallelization across instances - var step small_rational.SmallRational + var ( + step small_rational.SmallRational + api gateAPI + ) res := make([]small_rational.SmallRational, degGJ) @@ -266,10 +270,11 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { for i := range gateInput { gateInput[i] = &mlEvals[eIndex+1+i] } - summand := wire.Gate.Evaluate(api, gateInput...).(*small_rational.SmallRational) + summand := wire.Gate.Evaluate(&api, gateInput...).(*small_rational.SmallRational) summand.Mul(summand, &mlEvals[eIndex]) res[d].Add(&res[d], summand) // collect contributions into the sum from start to end eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + api.freeElements() } } mu.Lock() @@ -668,6 +673,7 @@ func (a WireAssignment) Complete(wires gkrtypes.Wires) WireAssignment { } } + var api gateAPI ins := make([]small_rational.SmallRational, maxNbIns) for i := range nbInstances { for wI, w := range wires { @@ -724,70 +730,72 @@ func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { } // gateAPI implements gkr.GateAPI. -type gateAPI struct{} - -var api gateAPI +// It uses a synchronous memory pool underneath to minimize heap allocations. +type gateAPI struct { + allocated []*small_rational.SmallRational + nbUsed int +} -func (gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational // TODO Heap allocated. Keep an eye on perf - res.Add(cast(i1), cast(i2)) +func (api *gateAPI) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Add(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Add(&res, cast(v)) + res.Add(res, api.cast(v)) } - return &res + return res } -func (gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var prod small_rational.SmallRational - prod.Add(cast(b), cast(c)) - res := cast(a) - res.Add(res, &prod) - return &res +func (api *gateAPI) MulAcc(a, b, c frontend.Variable) frontend.Variable { + prod := api.newElement() + prod.Mul(api.cast(b), api.cast(c)) + res := api.cast(a) + res.Add(res, prod) + return res } -func (gateAPI) Neg(i1 frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Neg(cast(i1)) - return &res +func (api *gateAPI) Neg(i1 frontend.Variable) frontend.Variable { + res := api.newElement() + res.Neg(api.cast(i1)) + return res } -func (gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Sub(cast(i1), cast(i2)) +func (api *gateAPI) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Sub(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Sub(&res, cast(v)) + res.Sub(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res small_rational.SmallRational - res.Mul(cast(i1), cast(i2)) +func (api *gateAPI) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + res := api.newElement() + res.Mul(api.cast(i1), api.cast(i2)) for _, v := range in { - res.Mul(&res, cast(v)) + res.Mul(res, api.cast(v)) } - return &res + return res } -func (gateAPI) Println(a ...frontend.Variable) { +func (api *gateAPI) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) var x small_rational.SmallRational for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) } -func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { +func (api *gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRational) *small_rational.SmallRational { inVar := make([]frontend.Variable, len(in)) for i := range in { inVar[i] = &in[i] @@ -795,22 +803,35 @@ func (api gateAPI) evaluate(f gkr.GateFunction, in ...small_rational.SmallRation return f(api, inVar...).(*small_rational.SmallRational) } +// Put all elements back in the pool. +func (api *gateAPI) freeElements() { + api.nbUsed = 0 +} + +func (api *gateAPI) newElement() *small_rational.SmallRational { + api.nbUsed++ + if api.nbUsed >= len(api.allocated) { + api.allocated = append(api.allocated, new(small_rational.SmallRational)) + } + return api.allocated[api.nbUsed-1] +} + type gateFunctionFr func(...small_rational.SmallRational) *small_rational.SmallRational // convertFunc turns f into a function that accepts and returns small_rational.SmallRational. -func (api gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { +func (api *gateAPI) convertFunc(f gkr.GateFunction) gateFunctionFr { return func(in ...small_rational.SmallRational) *small_rational.SmallRational { return api.evaluate(f, in...) } } -func cast(v frontend.Variable) *small_rational.SmallRational { +func (api *gateAPI) cast(v frontend.Variable) *small_rational.SmallRational { if x, ok := v.(*small_rational.SmallRational); ok { // fast path, no extra heap allocation return x } - var x small_rational.SmallRational + x := api.newElement() if _, err := x.SetInterface(v); err != nil { panic(err) } - return &x + return x }