From 1673f7043a73cff7afbdc1dcbd0466499484186e Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Wed, 7 May 2025 14:09:08 +0200 Subject: [PATCH 01/15] test commit --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 99c267de4..869eefa1c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ + # gnark-crypto [![Twitter URL](https://img.shields.io/twitter/url/https/twitter.com/gnark_team.svg?style=social&label=Follow%20%40gnark_team)](https://x.com/gnark_team) [![License](https://img.shields.io/badge/license-Apache%202-blue)](LICENSE) [![Go Report Card](https://goreportcard.com/badge/github.com/Consensys/gnark-crypto)](https://goreportcard.com/badge/github.com/Consensys/gnark-crypto) [![PkgGoDev](https://pkg.go.dev/badge/mod/github.com/consensys/gnark-crypto)](https://pkg.go.dev/mod/github.com/consensys/gnark-crypto) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5815453.svg)](https://doi.org/10.5281/zenodo.5815453) From fd672b97654785f3b0d3df29371d2b6c25f0ec7b Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Wed, 7 May 2025 20:25:27 +0200 Subject: [PATCH 02/15] feat: add fftext --- README.md | 1 - field/koalabear/extensions/e4.go | 9 + field/koalabear/fft/bitreverse.go | 28 ++ field/koalabear/fft/fftext.go | 406 ++++++++++++++++++++++++++ field/koalabear/fft/fftext_test.go | 414 +++++++++++++++++++++++++++ field/koalabear/fft/kernel_purego.go | 16 ++ 6 files changed, 873 insertions(+), 1 deletion(-) create mode 100644 field/koalabear/fft/fftext.go create mode 100644 field/koalabear/fft/fftext_test.go diff --git a/README.md b/README.md index 869eefa1c..99c267de4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ - # gnark-crypto [![Twitter URL](https://img.shields.io/twitter/url/https/twitter.com/gnark_team.svg?style=social&label=Follow%20%40gnark_team)](https://x.com/gnark_team) [![License](https://img.shields.io/badge/license-Apache%202-blue)](LICENSE) [![Go Report Card](https://goreportcard.com/badge/github.com/Consensys/gnark-crypto)](https://goreportcard.com/badge/github.com/Consensys/gnark-crypto) [![PkgGoDev](https://pkg.go.dev/badge/mod/github.com/consensys/gnark-crypto)](https://pkg.go.dev/mod/github.com/consensys/gnark-crypto) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5815453.svg)](https://doi.org/10.5281/zenodo.5815453) diff --git a/field/koalabear/extensions/e4.go b/field/koalabear/extensions/e4.go index af45fedc8..fcab0827b 100644 --- a/field/koalabear/extensions/e4.go +++ b/field/koalabear/extensions/e4.go @@ -369,3 +369,12 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { mulAccE4_avx512(alpha, &scale[0], &res[0], uint64(N)) } + +// Butterfly computes the butterfly operation on two E4 elements +func Butterfly(a, b *E4) { + fr.Butterfly(&a.B0.A0, &b.B0.A0) + fr.Butterfly(&a.B0.A1, &b.B0.A1) + + fr.Butterfly(&a.B1.A0, &b.B1.A0) + fr.Butterfly(&a.B1.A1, &b.B1.A1) +} diff --git a/field/koalabear/fft/bitreverse.go b/field/koalabear/fft/bitreverse.go index b1e953533..46c20fe47 100644 --- a/field/koalabear/fft/bitreverse.go +++ b/field/koalabear/fft/bitreverse.go @@ -11,6 +11,34 @@ import ( "github.com/consensys/gnark-crypto/field/koalabear" ) +// BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. +// It is a generic function that works on slices of any type T. +// Type T can be for example: koalabear.Element or fext.E4. +func BitReverseGeneric[T any](v []T) { + n := uint64(len(v)) + // Check if the length is a power of 2 using bit manipulation + if bits.OnesCount64(n) != 1 { + panic("len(v) must be a power of 2") + } + + // This is the naive bit-reversal algorithm (in-place swap) + // nn is used to calculate the significant bits for reversal + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + // Calculate the bit-reversed index + // bits.Reverse64 reverses the 64-bit representation of i + // We then right-shift by nn to get the reversal within the range [0, n-1] + iRev := bits.Reverse64(i) >> nn + + // Swap elements only if the reversed index is greater than the current index + // This prevents swapping elements twice (i -> iRev, then later iRev -> i) + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 func BitReverse(v []koalabear.Element) { diff --git a/field/koalabear/fft/fftext.go b/field/koalabear/fft/fftext.go new file mode 100644 index 000000000..60c45fa6a --- /dev/null +++ b/field/koalabear/fft/fftext.go @@ -0,0 +1,406 @@ +package fft + +import ( + "math/big" + "math/bits" + + "github.com/consensys/gnark-crypto/ecc" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" + "github.com/consensys/gnark-crypto/internal/parallel" + + "github.com/consensys/gnark-crypto/field/koalabear" +) + +// FFTExt computes the discrete Fourier transform of a slice of extension field elements. +// Coefficients and evaluations are extension field elements. +// The root of unity domain is the same as FFT, they are koalabear elements. +func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { + + opt := fftOptions(opts) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } + + // if coset != 0, scale by coset table + if opt.coset { + + if decimation == DIT { + + // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]koalabear.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].MulByElement(&a[i], &cosetTable[irev]) + } + }, opt.nbTasks) + } else { + + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at koalabear.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + + } + } + + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + if nbStages-twiddlesStartStage > 0 { + twiddles = make([][]koalabear.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]koalabear.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1<> nn) + a[i].MulByElement(&a[i], &cosetTableInv[irev]). + MulByElement(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + +} + +func difFFTExt(a []fext.E4, w koalabear.Element, twiddles [][]koalabear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDIFNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + m := n >> 1 + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at koalabear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddlesExt(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) + } else { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + +} + +func innerDIFWithTwiddlesGenericExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddlesExt(a []fext.E4, at, w koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFTExt(a []fext.E4, w koalabear.Element, twiddles [][]koalabear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDITNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + + m := n >> 1 + + nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) + + if stage < maxSplits { + // that's the only time we fire go routines + chDone := make(chan struct{}, 1) + + go ditFFTExt(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFTExt(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at koalabear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + + } else { + innerDITWithoutTwiddlesExt(a, w, w, 0, m, m) + } + return + } + + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) +} + +func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + fext.Butterfly(&a[i], &a[i+m]) + + } + +} + +func innerDITWithoutTwiddlesExt(a []fext.E4, at, w koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &at) + fext.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } +} + +func kerDIFNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + innerDIFWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) +} diff --git a/field/koalabear/fft/fftext_test.go b/field/koalabear/fft/fftext_test.go new file mode 100644 index 000000000..c31dc151a --- /dev/null +++ b/field/koalabear/fft/fftext_test.go @@ -0,0 +1,414 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "strconv" + "testing" + + "math/big" + + "fmt" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "encoding/binary" + + "math/rand/v2" + + "github.com/stretchr/testify/require" + + "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" +) + +func TestFFTExt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 6 + properties := gopter.NewProperties(parameters) + + for maxSize := 2; maxSize <= 1<<10; maxSize <<= 1 { + + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) + + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF) + BitReverseGeneric(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF, OnCoset()) + BitReverseGeneric(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverseGeneric(pol) + domain.FFTExt(pol, DIT) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverseGeneric(pol) + domain.FFTExt(pol, DIT) + domain.FFTInverseExt(pol, DIF) + BitReverseGeneric(pol) + + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) + + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + check := true + + for i := 1; i <= nbCosets; i++ { + + BitReverseGeneric(pol) + domain.FFTExt(pol, DIT, OnCoset()) + domain.FFTInverseExt(pol, DIF, OnCoset()) + BitReverseGeneric(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } + + return check + }, + )) + } + + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF) + domain.FFTExt(pol, DIT) + + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF, OnCoset()) + domain.FFTExt(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverseExt(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFTExt(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + return true + }, + )) + } + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } + +} + +func randElementExt(rng *rand.Rand) fext.E4 { + var v fext.E4 + v.B0.A0 = koalabear.Element{rng.Uint32N(2130706433)} + v.B0.A1 = koalabear.Element{rng.Uint32N(2130706433)} + v.B1.A0 = koalabear.Element{rng.Uint32N(2130706433)} + v.B1.A1 = koalabear.Element{rng.Uint32N(2130706433)} + return v +} + +func FuzzFFTExt(f *testing.F) { + f.Fuzz(func(t *testing.T, domainSize uint16, rngSeed int64) { + if domainSize > (1 << 13) { + t.Skip("domain size too large") + } + if domainSize < 2 { + t.Skip("domain size too small") + } + + domain := NewDomain(uint64(domainSize)) + + var seed [32]byte + binary.PutVarint(seed[:], rngSeed) + // #nosec G404 -- fuzz does not require a cryptographic PRNG + rng := rand.New(rand.NewChaCha8(seed)) + + cardinality := domain.Cardinality + + // we just check that FFT-1(FFT(pol)) == pol + a, b := make([]fext.E4, cardinality), make([]fext.E4, cardinality) + for i := 0; i < int(cardinality); i++ { + a[i] = randElementExt(rng) + } + copy(b, a) + + domain.FFTInverseExt(a, DIF) + domain.FFTExt(a, DIT) + + assert := require.New(t) + for i := 0; i < int(cardinality); i++ { + assert.True(a[i].Equal(&b[i]), "FFT-1(FFT(pol)) != pol at index %d", i) + } + }) +} + +// -------------------------------------------------------------------- +// benches + +func BenchmarkFFTExt(b *testing.B) { + + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + for i := 8; i < 20; i++ { + sizeDomain := 1 << i + b.Run("fft 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT) + } + }) + b.Run("fft 2**"+strconv.Itoa(i)+"bits (coset)", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT, OnCoset()) + } + }) + } + +} + +func BenchmarkFFTDITCosetReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIT, OnCoset()) + } +} + +func BenchmarkFFTDITReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < 1; j++ { + domain.FFTExt(pol, DIT) + } +} + +func BenchmarkFFTDIFReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} +func BenchmarkFFTDIFReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} + +func evaluatePolynomialExt(pol []fext.E4, val koalabear.Element) fext.E4 { + var res, tmp fext.E4 + var acc koalabear.Element + res.Set(&pol[0]) + acc.Set(&val) + for i := 1; i < len(pol); i++ { + tmp.MulByElement(&pol[i], &acc) + res.Add(&res, &tmp) + acc.Mul(&acc, &val) + } + return res +} diff --git a/field/koalabear/fft/kernel_purego.go b/field/koalabear/fft/kernel_purego.go index c711d7d72..ac907352e 100644 --- a/field/koalabear/fft/kernel_purego.go +++ b/field/koalabear/fft/kernel_purego.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) func innerDIFWithTwiddles(a []koalabear.Element, twiddles []koalabear.Element, start, end, m int) { @@ -25,3 +26,18 @@ func kerDIFNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage i func kerDITNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage int) { kerDITNP_256generic(a, twiddles, stage) } + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + kerDIFNP_256genericExt(a, twiddles, stage) +} +func kerDITNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + kerDITNP_256genericExt(a, twiddles, stage) +} From 85370847513979170192f3b20e07fd57a28125d3 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Thu, 8 May 2025 22:08:38 +0200 Subject: [PATCH 03/15] fix: add butterfly for E4 in e4.go.tmpl --- field/babybear/extensions/e4.go | 9 +++++++++ .../internal/templates/extensions/e4.go.tmpl | 12 +++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/field/babybear/extensions/e4.go b/field/babybear/extensions/e4.go index 957710e3e..30d4d2115 100644 --- a/field/babybear/extensions/e4.go +++ b/field/babybear/extensions/e4.go @@ -369,3 +369,12 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { mulAccE4_avx512(alpha, &scale[0], &res[0], uint64(N)) } + +// Butterfly computes the butterfly operation on two E4 elements +func Butterfly(a, b *E4) { + fr.Butterfly(&a.B0.A0, &b.B0.A0) + fr.Butterfly(&a.B0.A1, &b.B0.A1) + + fr.Butterfly(&a.B1.A0, &b.B1.A0) + fr.Butterfly(&a.B1.A1, &b.B1.A1) +} diff --git a/field/generator/internal/templates/extensions/e4.go.tmpl b/field/generator/internal/templates/extensions/e4.go.tmpl index 50c56fee5..5aba79ddb 100644 --- a/field/generator/internal/templates/extensions/e4.go.tmpl +++ b/field/generator/internal/templates/extensions/e4.go.tmpl @@ -366,4 +366,14 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { mulAccE4_avx512(alpha, &scale[0], &res[0], uint64(N)) } -{{- end}} \ No newline at end of file +{{- end}} + + +// Butterfly computes the butterfly operation on two E4 elements +func Butterfly(a, b *E4) { + fr.Butterfly(&a.B0.A0, &b.B0.A0) + fr.Butterfly(&a.B0.A1, &b.B0.A1) + + fr.Butterfly(&a.B1.A0, &b.B1.A0) + fr.Butterfly(&a.B1.A1, &b.B1.A1) +} From e2db2a432abd06052bd2ef0c423312fcdc87295e Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Thu, 8 May 2025 23:06:43 +0200 Subject: [PATCH 04/15] fix: add BitReverseGeneric and SmallFiled interface in bitreverse.go.tmpl --- field/babybear/fft/bitreverse.go | 33 +++++++++++++++ .../internal/templates/extensions/e4.go.tmpl | 1 - .../internal/templates/fft/bitreverse.go.tmpl | 40 ++++++++++++++++++- field/koalabear/fft/bitreverse.go | 7 +++- 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/field/babybear/fft/bitreverse.go b/field/babybear/fft/bitreverse.go index 6c376c898..a753d7023 100644 --- a/field/babybear/fft/bitreverse.go +++ b/field/babybear/fft/bitreverse.go @@ -9,8 +9,41 @@ import ( "math/bits" "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) +type SmallFiled interface { + babybear.Element | fext.E4 +} + +// BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. +// It is a generic function that works on slices of any type T. +// Type T can be for example: koalabear.Element or fext.E4. +func BitReverseGeneric[T SmallFiled](v []T) { + n := uint64(len(v)) + // Check if the length is a power of 2 using bit manipulation + if bits.OnesCount64(n) != 1 { + panic("len(v) must be a power of 2") + } + + // This is the naive bit-reversal algorithm (in-place swap) + // nn is used to calculate the significant bits for reversal + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + // Calculate the bit-reversed index + // bits.Reverse64 reverses the 64-bit representation of i + // We then right-shift by nn to get the reversal within the range [0, n-1] + iRev := bits.Reverse64(i) >> nn + + // Swap elements only if the reversed index is greater than the current index + // This prevents swapping elements twice (i -> iRev, then later iRev -> i) + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} + // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 func BitReverse(v []babybear.Element) { diff --git a/field/generator/internal/templates/extensions/e4.go.tmpl b/field/generator/internal/templates/extensions/e4.go.tmpl index 5aba79ddb..0671f9683 100644 --- a/field/generator/internal/templates/extensions/e4.go.tmpl +++ b/field/generator/internal/templates/extensions/e4.go.tmpl @@ -368,7 +368,6 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { } {{- end}} - // Butterfly computes the butterfly operation on two E4 elements func Butterfly(a, b *E4) { fr.Butterfly(&a.B0.A0, &b.B0.A0) diff --git a/field/generator/internal/templates/fft/bitreverse.go.tmpl b/field/generator/internal/templates/fft/bitreverse.go.tmpl index c99e3451f..06e4c1697 100644 --- a/field/generator/internal/templates/fft/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/bitreverse.go.tmpl @@ -6,8 +6,45 @@ import ( {{- end}} "{{ .FieldPackagePath }}" + {{- if .F31}} + fext "{{ .FieldPackagePath }}/extensions" + {{- end}} ) +{{- if .F31}} +type SmallFiled interface { + {{ .FF }}.Element | fext.E4 +} + +// BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. +// It is a generic function that works on slices of any type T. +// Type T can be for example: koalabear.Element or fext.E4. +func BitReverseGeneric[T SmallFiled](v []T) { + n := uint64(len(v)) + // Check if the length is a power of 2 using bit manipulation + if bits.OnesCount64(n) != 1 { + panic("len(v) must be a power of 2") + } + + // This is the naive bit-reversal algorithm (in-place swap) + // nn is used to calculate the significant bits for reversal + nn := uint64(64 - bits.TrailingZeros64(n)) + + for i := uint64(0); i < n; i++ { + // Calculate the bit-reversed index + // bits.Reverse64 reverses the 64-bit representation of i + // We then right-shift by nn to get the reversal within the range [0, n-1] + iRev := bits.Reverse64(i) >> nn + + // Swap elements only if the reversed index is greater than the current index + // This prevents swapping elements twice (i -> iRev, then later iRev -> i) + if iRev > i { + v[i], v[iRev] = v[iRev], v[i] + } + } +} +{{- end}} + // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 func BitReverse(v []{{ .FF }}.Element) { @@ -244,4 +281,5 @@ func bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}}(v []{{ .FF }}.Element) { } -{{- end}} \ No newline at end of file +{{- end}} + diff --git a/field/koalabear/fft/bitreverse.go b/field/koalabear/fft/bitreverse.go index 46c20fe47..0cd921c54 100644 --- a/field/koalabear/fft/bitreverse.go +++ b/field/koalabear/fft/bitreverse.go @@ -9,12 +9,17 @@ import ( "math/bits" "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) +type SmallFiled interface { + koalabear.Element | fext.E4 +} + // BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. // It is a generic function that works on slices of any type T. // Type T can be for example: koalabear.Element or fext.E4. -func BitReverseGeneric[T any](v []T) { +func BitReverseGeneric[T SmallFiled](v []T) { n := uint64(len(v)) // Check if the length is a power of 2 using bit manipulation if bits.OnesCount64(n) != 1 { From 5e6dd6d5296c38bb2b0ffeaa013f157cd5315b30 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Thu, 8 May 2025 23:11:59 +0200 Subject: [PATCH 05/15] fix: add ext funcs in kernel.purego.go.tmpl --- .../templates/fft/kernel.purego.go.tmpl | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/field/generator/internal/templates/fft/kernel.purego.go.tmpl b/field/generator/internal/templates/fft/kernel.purego.go.tmpl index 99d75c3d9..35cce050c 100644 --- a/field/generator/internal/templates/fft/kernel.purego.go.tmpl +++ b/field/generator/internal/templates/fft/kernel.purego.go.tmpl @@ -1,5 +1,8 @@ import ( "{{ .FieldPackagePath }}" + {{- if .F31}} + fext "{{ .FieldPackagePath }}/extensions" + {{- end}} ) func innerDIFWithTwiddles(a []{{ .FF }}.Element, twiddles []{{ .FF }}.Element, start, end, m int) { @@ -18,4 +21,23 @@ func kerDIFNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element func kerDITNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element, stage int) { kerDITNP_{{$ksize}}generic(a, twiddles, stage) } -{{end}} \ No newline at end of file +{{end}} + +{{- if .F31}} +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) +} +{{range $ki, $klog2 := $.Kernels}} + {{- $ksize := shl 1 $klog2}} +func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { + kerDIFNP_{{$ksize}}genericExt(a, twiddles, stage) +} +func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { + kerDITNP_{{$ksize}}genericExt(a, twiddles, stage) +} +{{end}} +{{- end}} \ No newline at end of file From 12b44655780ea2b730b6bdece6bb9e8723eb43a3 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 00:04:50 +0200 Subject: [PATCH 06/15] fix: add tmpl for fftext --- .../internal/templates/fft/fftext.go .tmpl | 361 ++++++++++++++++ .../templates/fft/tests/fftext.go.tmpl | 399 ++++++++++++++++++ field/koalabear/fft/fftext.go | 4 +- field/koalabear/fft/fftext_test.go | 3 +- 4 files changed, 763 insertions(+), 4 deletions(-) create mode 100644 field/generator/internal/templates/fft/fftext.go .tmpl create mode 100644 field/generator/internal/templates/fft/tests/fftext.go.tmpl diff --git a/field/generator/internal/templates/fft/fftext.go .tmpl b/field/generator/internal/templates/fft/fftext.go .tmpl new file mode 100644 index 000000000..2c8fe52d3 --- /dev/null +++ b/field/generator/internal/templates/fft/fftext.go .tmpl @@ -0,0 +1,361 @@ +import ( + "math/big" + "math/bits" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/internal/parallel" + + "{{ .FieldPackagePath }}" + fext "{{ .FieldPackagePath }}/extensions" +) + + +// FFTExt computes the discrete Fourier transform of a slice of extension field elements. +// Coefficients and evaluations are extension field elements. +// The root of unity domain is the same as FFT. +func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { + + opt := fftOptions(opts) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } + + // if coset != 0, scale by coset table + if opt.coset { + + if decimation == DIT { + + // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]{{ .FF }}.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].MulByElement(&a[i], &cosetTable[irev]) + } + }, opt.nbTasks) + } else { + + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at {{ .FF }}.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + + } + } + + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + if nbStages-twiddlesStartStage > 0 { + twiddles = make([][]{{ .FF }}.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]{{ .FF }}.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1<> nn) + a[i].MulByElement(&a[i], &cosetTableInv[irev]). + MulByElement(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + +} + + +func difFFTExt(a []fext.E4, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + {{- range $ki, $klog2 := $.Kernels}} + {{- if ne $ki 0}} else {{- end}} if n == 1 << {{$klog2}} { + {{- $ksize := shl 1 $klog2}} + kerDIFNP_{{$ksize}}Ext(a, twiddles, stage-twiddlesStartStage) + return + } + {{- end }} + } + m := n >> 1 + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at {{ .FF }}.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddlesExt(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) + } else { + {{- if .HasASMKernel}} + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + {{- else}} + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks / (1 << (stage))) + } else { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + {{- end}} + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + +} + +func innerDIFWithTwiddlesGenericExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddlesExt(a []fext.E4, at, w {{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFTExt(a []fext.E4, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + {{- range $ki, $klog2 := $.Kernels}} + {{- if ne $ki 0}} else {{- end}} if n == 1 << {{$klog2}} { + {{- $ksize := shl 1 $klog2}} + kerDITNP_{{$ksize}}Ext(a, twiddles, stage-twiddlesStartStage) + return + } + {{- end }} + } + + m := n >> 1 + + nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) + + if stage < maxSplits { + // that's the only time we fire go routines + chDone := make(chan struct{}, 1) + + go ditFFTExt(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFTExt(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at {{ .FF }}.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + + } else { + innerDITWithoutTwiddlesExt(a, w, w, 0, m, m) + } + return + } + {{- if .HasASMKernel}} + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + {{- else}} + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks / (1 << (stage))) + } else { + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + {{- end}} +} diff --git a/field/generator/internal/templates/fft/tests/fftext.go.tmpl b/field/generator/internal/templates/fft/tests/fftext.go.tmpl new file mode 100644 index 000000000..f6d707442 --- /dev/null +++ b/field/generator/internal/templates/fft/tests/fftext.go.tmpl @@ -0,0 +1,399 @@ +import ( + "math/big" + "testing" + "strconv" + + + "{{ .FieldPackagePath }}" + fext "{{ .FieldPackagePath }}/extensions" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" + "github.com/leanovate/gopter/gen" + + "fmt" + {{- if .F31}} + "github.com/stretchr/testify/require" + "encoding/binary" + "math/rand/v2" + {{- end}} +) + +func TestFFTExt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 6 + properties := gopter.NewProperties(parameters) + + for maxSize := 2; maxSize <= 1<<10; maxSize <<= 1 { + + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) + + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF) + BitReverseGeneric(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF, OnCoset()) + BitReverseGeneric(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverseGeneric(pol) + domain.FFTExt(pol, DIT) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverseGeneric(pol) + domain.FFTExt(pol, DIT) + domain.FFTInverseExt(pol, DIF) + BitReverseGeneric(pol) + + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) + + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + check := true + + for i := 1; i <= nbCosets; i++ { + + BitReverseGeneric(pol) + domain.FFTExt(pol, DIT, OnCoset()) + domain.FFTInverseExt(pol, DIF, OnCoset()) + BitReverseGeneric(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } + + return check + }, + )) + } + + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF) + domain.FFTExt(pol, DIT) + + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF, OnCoset()) + domain.FFTExt(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverseExt(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFTExt(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + return true + }, + )) + } + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } + +} +{{- if .F31}} + +func randElementExt(rng *rand.Rand) fext.E4 { + var v fext.E4 + v.B0.A0 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + v.B0.A1 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + v.B1.A0 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + v.B1.A1 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + return v +} + +func FuzzFFTExt(f *testing.F) { + f.Fuzz(func(t *testing.T, domainSize uint16, rngSeed int64) { + if domainSize > (1 << 13) { + t.Skip("domain size too large") + } + if domainSize < 2 { + t.Skip("domain size too small") + } + + domain := NewDomain(uint64(domainSize)) + + var seed [32]byte + binary.PutVarint(seed[:], rngSeed) + // #nosec G404 -- fuzz does not require a cryptographic PRNG + rng := rand.New(rand.NewChaCha8(seed)) + + cardinality := domain.Cardinality + + // we just check that FFT-1(FFT(pol)) == pol + a, b := make([]fext.E4, cardinality), make([]fext.E4, cardinality) + for i := 0; i < int(cardinality); i++ { + a[i] = randElementExt(rng) + } + copy(b, a) + + domain.FFTInverseExt(a, DIF) + domain.FFTExt(a, DIT) + + assert := require.New(t) + for i := 0; i < int(cardinality); i++ { + assert.True(a[i].Equal(&b[i]), "FFT-1(FFT(pol)) != pol at index %d", i) + } + }) +} + +{{- end}} + +// -------------------------------------------------------------------- +// benches + +func BenchmarkFFTExt(b *testing.B) { + + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + for i := 8; i < 20; i++ { + sizeDomain := 1 << i + b.Run("fft 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT) + } + }) + b.Run("fft 2**"+strconv.Itoa(i)+"bits (coset)", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT, OnCoset()) + } + }) + } + +} + +func BenchmarkFFTDITCosetReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIT, OnCoset()) + } +} + +func BenchmarkFFTDITReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < 1; j++ { + domain.FFTExt(pol, DIT) + } +} + +func BenchmarkFFTDIFReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} +func BenchmarkFFTDIFReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} + +func evaluatePolynomialExt(pol []fext.E4, val {{ .FF }}.Element) fext.E4 { + var res, tmp fext.E4 + var acc {{ .FF }}.Element + res.Set(&pol[0]) + acc.Set(&val) + for i := 1; i < len(pol); i++ { + tmp.MulByElement(&pol[i], &acc) + res.Add(&res, &tmp) + acc.Mul(&acc, &val) + } + return res +} diff --git a/field/koalabear/fft/fftext.go b/field/koalabear/fft/fftext.go index 60c45fa6a..eb17dbeb8 100644 --- a/field/koalabear/fft/fftext.go +++ b/field/koalabear/fft/fftext.go @@ -5,15 +5,15 @@ import ( "math/bits" "github.com/consensys/gnark-crypto/ecc" - fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" "github.com/consensys/gnark-crypto/internal/parallel" "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) // FFTExt computes the discrete Fourier transform of a slice of extension field elements. // Coefficients and evaluations are extension field elements. -// The root of unity domain is the same as FFT, they are koalabear elements. +// The root of unity domain is the same as FFT. func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { opt := fftOptions(opts) diff --git a/field/koalabear/fft/fftext_test.go b/field/koalabear/fft/fftext_test.go index c31dc151a..8a5fe392e 100644 --- a/field/koalabear/fft/fftext_test.go +++ b/field/koalabear/fft/fftext_test.go @@ -17,11 +17,10 @@ package fft import ( + "math/big" "strconv" "testing" - "math/big" - "fmt" "github.com/leanovate/gopter" "github.com/leanovate/gopter/gen" From 7da3849157554f804457cc5494e3f594ed81183a Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 13:36:02 +0200 Subject: [PATCH 07/15] refactor: bitreverse, step1 --- field/koalabear/fft/bitreverse.go | 8 +- field/koalabear/fft/bitreverse_test.go | 125 +++++++++++++++++++++---- field/koalabear/fft/fftext_test.go | 14 +-- 3 files changed, 116 insertions(+), 31 deletions(-) diff --git a/field/koalabear/fft/bitreverse.go b/field/koalabear/fft/bitreverse.go index 0cd921c54..9ec22b662 100644 --- a/field/koalabear/fft/bitreverse.go +++ b/field/koalabear/fft/bitreverse.go @@ -12,14 +12,14 @@ import ( fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) -type SmallFiled interface { +type SmallField interface { koalabear.Element | fext.E4 } // BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. // It is a generic function that works on slices of any type T. // Type T can be for example: koalabear.Element or fext.E4. -func BitReverseGeneric[T SmallFiled](v []T) { +func BitReverseGeneric[T SmallField](v []T) { n := uint64(len(v)) // Check if the length is a power of 2 using bit manipulation if bits.OnesCount64(n) != 1 { @@ -46,7 +46,7 @@ func BitReverseGeneric[T SmallFiled](v []T) { // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse(v []koalabear.Element) { +func BitReverse[T SmallField](v []T) { n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -57,7 +57,7 @@ func BitReverse(v []koalabear.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive(v []koalabear.Element) { +func bitReverseNaive[T SmallField](v []T) { n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/koalabear/fft/bitreverse_test.go b/field/koalabear/fft/bitreverse_test.go index d1dd3cfcf..8be29d771 100644 --- a/field/koalabear/fft/bitreverse_test.go +++ b/field/koalabear/fft/bitreverse_test.go @@ -10,22 +10,24 @@ import ( "testing" "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" + ) -type bitReverseVariant struct { +type bitReverseVariant[T SmallField] struct { name string - buf []koalabear.Element - fn func([]koalabear.Element) + buf []T + fn func([]T) } const maxSizeBitReverse = 1 << 23 - -var bitReverse = []bitReverseVariant{ - {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive}, - {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse}, +var koalabearBitReverse = []bitReverseVariant[koalabear.Element]{ + {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive[koalabear.Element]}, + {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse[koalabear.Element]}, } -func TestBitReverse(t *testing.T) { + +func TestKoalabearBitReverse(t *testing.T) { // generate a random []koalabear.Element array of size 2**20 pol := make([]koalabear.Element, maxSizeBitReverse) @@ -39,33 +41,33 @@ func TestBitReverse(t *testing.T) { for size := 2; size <= maxSizeBitReverse; size <<= 1 { // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { copy(data.buf, pol[:size]) } // compute bit reverse shuffling - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { data.fn(data.buf[:size]) } // all bitReverse.buf should hold the same result for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(koalabearBitReverse); j++ { + if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) } } } // bitReverse back should be identity - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { data.fn(data.buf[:size]) } for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(koalabearBitReverse); j++ { + if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) } } } @@ -73,7 +75,7 @@ func TestBitReverse(t *testing.T) { } -func BenchmarkBitReverse(b *testing.B) { +func BenchmarkKoalabearBitReverse(b *testing.B) { // generate a random []koalabear.Element array of size 2**22 pol := make([]koalabear.Element, maxSizeBitReverse) one := koalabear.One() @@ -83,13 +85,96 @@ func BenchmarkBitReverse(b *testing.B) { } // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range koalabearBitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + + +func TestE4BitReverse(t *testing.T) { + + // generate a random []koalabear.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []koalabear.Element array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { copy(data.buf, pol[:maxSizeBitReverse]) } // benchmark for each size, each bitReverse function for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range bitReverse { + for _, data := range e4BitReverse { b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { diff --git a/field/koalabear/fft/fftext_test.go b/field/koalabear/fft/fftext_test.go index 8a5fe392e..21343e4b5 100644 --- a/field/koalabear/fft/fftext_test.go +++ b/field/koalabear/fft/fftext_test.go @@ -67,7 +67,7 @@ func TestFFTExt(t *testing.T) { copy(backupPol, pol) domain.FFTExt(pol, DIF) - BitReverseGeneric(pol) + BitReverse(pol) sample := domain.Generator sample.Exp(sample, big.NewInt(int64(ithpower))) @@ -94,7 +94,7 @@ func TestFFTExt(t *testing.T) { copy(backupPol, pol) domain.FFTExt(pol, DIF, OnCoset()) - BitReverseGeneric(pol) + BitReverse(pol) sample := domain.Generator sample.Exp(sample, big.NewInt(int64(ithpower))). @@ -121,7 +121,7 @@ func TestFFTExt(t *testing.T) { } copy(backupPol, pol) - BitReverseGeneric(pol) + BitReverse(pol) domain.FFTExt(pol, DIT) sample := domain.Generator @@ -147,10 +147,10 @@ func TestFFTExt(t *testing.T) { } copy(backupPol, pol) - BitReverseGeneric(pol) + BitReverse(pol) domain.FFTExt(pol, DIT) domain.FFTInverseExt(pol, DIF) - BitReverseGeneric(pol) + BitReverse(pol) check := true for i := 0; i < len(pol); i++ { @@ -177,10 +177,10 @@ func TestFFTExt(t *testing.T) { for i := 1; i <= nbCosets; i++ { - BitReverseGeneric(pol) + BitReverse(pol) domain.FFTExt(pol, DIT, OnCoset()) domain.FFTInverseExt(pol, DIF, OnCoset()) - BitReverseGeneric(pol) + BitReverse(pol) for i := 0; i < len(pol); i++ { check = check && pol[i].Equal(&backupPol[i]) From 56ffc4623946e1d4210e618e815ddd7366f101ed Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 13:40:22 +0200 Subject: [PATCH 08/15] refactor: bitreverse, step1 --- .../internal/templates/fft/bitreverse.go.tmpl | 4 +- field/koalabear/fft/bitreverse.go | 4 +- field/koalabear/fft/bitreverse_test.go | 125 +++--------------- 3 files changed, 24 insertions(+), 109 deletions(-) diff --git a/field/generator/internal/templates/fft/bitreverse.go.tmpl b/field/generator/internal/templates/fft/bitreverse.go.tmpl index 06e4c1697..86c21c867 100644 --- a/field/generator/internal/templates/fft/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/bitreverse.go.tmpl @@ -12,14 +12,14 @@ import ( ) {{- if .F31}} -type SmallFiled interface { +type SmallField interface { {{ .FF }}.Element | fext.E4 } // BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. // It is a generic function that works on slices of any type T. // Type T can be for example: koalabear.Element or fext.E4. -func BitReverseGeneric[T SmallFiled](v []T) { +func BitReverseGeneric[T SmallField](v []T) { n := uint64(len(v)) // Check if the length is a power of 2 using bit manipulation if bits.OnesCount64(n) != 1 { diff --git a/field/koalabear/fft/bitreverse.go b/field/koalabear/fft/bitreverse.go index 9ec22b662..e86aa9924 100644 --- a/field/koalabear/fft/bitreverse.go +++ b/field/koalabear/fft/bitreverse.go @@ -46,7 +46,7 @@ func BitReverseGeneric[T SmallField](v []T) { // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse[T SmallField](v []T) { +func BitReverse(v []koalabear.Element) { n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -57,7 +57,7 @@ func BitReverse[T SmallField](v []T) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive[T SmallField](v []T) { +func bitReverseNaive(v []koalabear.Element) { n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/koalabear/fft/bitreverse_test.go b/field/koalabear/fft/bitreverse_test.go index 8be29d771..d1dd3cfcf 100644 --- a/field/koalabear/fft/bitreverse_test.go +++ b/field/koalabear/fft/bitreverse_test.go @@ -10,24 +10,22 @@ import ( "testing" "github.com/consensys/gnark-crypto/field/koalabear" - fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" - ) -type bitReverseVariant[T SmallField] struct { +type bitReverseVariant struct { name string - buf []T - fn func([]T) + buf []koalabear.Element + fn func([]koalabear.Element) } const maxSizeBitReverse = 1 << 23 -var koalabearBitReverse = []bitReverseVariant[koalabear.Element]{ - {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive[koalabear.Element]}, - {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse[koalabear.Element]}, -} +var bitReverse = []bitReverseVariant{ + {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive}, + {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse}, +} -func TestKoalabearBitReverse(t *testing.T) { +func TestBitReverse(t *testing.T) { // generate a random []koalabear.Element array of size 2**20 pol := make([]koalabear.Element, maxSizeBitReverse) @@ -41,33 +39,33 @@ func TestKoalabearBitReverse(t *testing.T) { for size := 2; size <= maxSizeBitReverse; size <<= 1 { // copy pol into the buffers - for _, data := range koalabearBitReverse { + for _, data := range bitReverse { copy(data.buf, pol[:size]) } // compute bit reverse shuffling - for _, data := range koalabearBitReverse { + for _, data := range bitReverse { data.fn(data.buf[:size]) } // all bitReverse.buf should hold the same result for i := 0; i < size; i++ { - for j := 1; j < len(koalabearBitReverse); j++ { - if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) } } } // bitReverse back should be identity - for _, data := range koalabearBitReverse { + for _, data := range bitReverse { data.fn(data.buf[:size]) } for i := 0; i < size; i++ { - for j := 1; j < len(koalabearBitReverse); j++ { - if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) + for j := 1; j < len(bitReverse); j++ { + if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) } } } @@ -75,7 +73,7 @@ func TestKoalabearBitReverse(t *testing.T) { } -func BenchmarkKoalabearBitReverse(b *testing.B) { +func BenchmarkBitReverse(b *testing.B) { // generate a random []koalabear.Element array of size 2**22 pol := make([]koalabear.Element, maxSizeBitReverse) one := koalabear.One() @@ -85,96 +83,13 @@ func BenchmarkKoalabearBitReverse(b *testing.B) { } // copy pol into the buffers - for _, data := range koalabearBitReverse { - copy(data.buf, pol[:maxSizeBitReverse]) - } - - // benchmark for each size, each bitReverse function - for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range koalabearBitReverse { - b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { - b.ResetTimer() - for j := 0; j < b.N; j++ { - data.fn(data.buf[:size]) - } - }) - } - } -} - -var e4BitReverse = []bitReverseVariant[fext.E4]{ - {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, - {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, -} - - -func TestE4BitReverse(t *testing.T) { - - // generate a random []koalabear.Element array of size 2**20 - pol := make([]fext.E4, maxSizeBitReverse) - var one fext.E4 - one.SetOne() - pol[0].MustSetRandom() - for i := 1; i < maxSizeBitReverse; i++ { - pol[i].Add(&pol[i-1], &one) - } - - // for each size, check that all the bitReverse functions fn compute the same result. - for size := 2; size <= maxSizeBitReverse; size <<= 1 { - - // copy pol into the buffers - for _, data := range e4BitReverse { - copy(data.buf, pol[:size]) - } - - // compute bit reverse shuffling - for _, data := range e4BitReverse { - data.fn(data.buf[:size]) - } - - // all bitReverse.buf should hold the same result - for i := 0; i < size; i++ { - for j := 1; j < len(e4BitReverse); j++ { - if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) - } - } - } - - // bitReverse back should be identity - for _, data := range e4BitReverse { - data.fn(data.buf[:size]) - } - - for i := 0; i < size; i++ { - for j := 1; j < len(e4BitReverse); j++ { - if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) - } - } - } - } - -} - -func BenchmarkE4BitReverse(b *testing.B) { - // generate a random []koalabear.Element array of size 2**22 - pol := make([]fext.E4, maxSizeBitReverse) - var one fext.E4 - one.SetOne() - pol[0].MustSetRandom() - for i := 1; i < maxSizeBitReverse; i++ { - pol[i].Add(&pol[i-1], &one) - } - - // copy pol into the buffers - for _, data := range e4BitReverse { + for _, data := range bitReverse { copy(data.buf, pol[:maxSizeBitReverse]) } // benchmark for each size, each bitReverse function for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range e4BitReverse { + for _, data := range bitReverse { b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { From 72b92403649b47c9c2a28198c9fd81265e75267f Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 14:00:22 +0200 Subject: [PATCH 09/15] refactor: update bitreverse.go.tmpl --- .../internal/templates/fft/bitreverse.go.tmpl | 36 +++++-------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/field/generator/internal/templates/fft/bitreverse.go.tmpl b/field/generator/internal/templates/fft/bitreverse.go.tmpl index 86c21c867..f4dee6d7c 100644 --- a/field/generator/internal/templates/fft/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/bitreverse.go.tmpl @@ -15,39 +15,15 @@ import ( type SmallField interface { {{ .FF }}.Element | fext.E4 } - -// BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. -// It is a generic function that works on slices of any type T. -// Type T can be for example: koalabear.Element or fext.E4. -func BitReverseGeneric[T SmallField](v []T) { - n := uint64(len(v)) - // Check if the length is a power of 2 using bit manipulation - if bits.OnesCount64(n) != 1 { - panic("len(v) must be a power of 2") - } - - // This is the naive bit-reversal algorithm (in-place swap) - // nn is used to calculate the significant bits for reversal - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - // Calculate the bit-reversed index - // bits.Reverse64 reverses the 64-bit representation of i - // We then right-shift by nn to get the reversal within the range [0, n-1] - iRev := bits.Reverse64(i) >> nn - - // Swap elements only if the reversed index is greater than the current index - // This prevents swapping elements twice (i -> iRev, then later iRev -> i) - if iRev > i { - v[i], v[iRev] = v[iRev], v[i] - } - } -} {{- end}} // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 +{{ if .F31}} +func BitReverse[T SmallField](v []T) { +{{- else}} func BitReverse(v []{{ .FF }}.Element) { +{{- end}} n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -66,7 +42,11 @@ func BitReverse(v []{{ .FF }}.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 +{{ if .F31}} +func bitReverseNaive[T SmallField](v []T) { +{{- else}} func bitReverseNaive(v []{{ .FF }}.Element) { +{{- end}} n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) From c7e8ecac8782b4634e6388ef0307af6ac28b695a Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 14:03:47 +0200 Subject: [PATCH 10/15] refactor: update bitreverse.go.tmpl --- field/babybear/fft/bitreverse.go | 30 +---------------- .../internal/templates/fft/bitreverse.go.tmpl | 12 +++---- field/koalabear/fft/bitreverse.go | 32 ++----------------- 3 files changed, 7 insertions(+), 67 deletions(-) diff --git a/field/babybear/fft/bitreverse.go b/field/babybear/fft/bitreverse.go index a753d7023..f59436bf2 100644 --- a/field/babybear/fft/bitreverse.go +++ b/field/babybear/fft/bitreverse.go @@ -12,38 +12,10 @@ import ( fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) -type SmallFiled interface { +type SmallField interface { babybear.Element | fext.E4 } -// BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. -// It is a generic function that works on slices of any type T. -// Type T can be for example: koalabear.Element or fext.E4. -func BitReverseGeneric[T SmallFiled](v []T) { - n := uint64(len(v)) - // Check if the length is a power of 2 using bit manipulation - if bits.OnesCount64(n) != 1 { - panic("len(v) must be a power of 2") - } - - // This is the naive bit-reversal algorithm (in-place swap) - // nn is used to calculate the significant bits for reversal - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - // Calculate the bit-reversed index - // bits.Reverse64 reverses the 64-bit representation of i - // We then right-shift by nn to get the reversal within the range [0, n-1] - iRev := bits.Reverse64(i) >> nn - - // Swap elements only if the reversed index is greater than the current index - // This prevents swapping elements twice (i -> iRev, then later iRev -> i) - if iRev > i { - v[i], v[iRev] = v[iRev], v[i] - } - } -} - // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 func BitReverse(v []babybear.Element) { diff --git a/field/generator/internal/templates/fft/bitreverse.go.tmpl b/field/generator/internal/templates/fft/bitreverse.go.tmpl index f4dee6d7c..47f9ea75a 100644 --- a/field/generator/internal/templates/fft/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/bitreverse.go.tmpl @@ -19,10 +19,8 @@ type SmallField interface { // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -{{ if .F31}} -func BitReverse[T SmallField](v []T) { -{{- else}} -func BitReverse(v []{{ .FF }}.Element) { +{{ if .F31}}func BitReverse[T SmallField](v []T) { +{{- else}}func BitReverse(v []{{ .FF }}.Element) { {{- end}} n := uint64(len(v)) if bits.OnesCount64(n) != 1 { @@ -42,10 +40,8 @@ func BitReverse(v []{{ .FF }}.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -{{ if .F31}} -func bitReverseNaive[T SmallField](v []T) { -{{- else}} -func bitReverseNaive(v []{{ .FF }}.Element) { +{{ if .F31}}func bitReverseNaive[T SmallField](v []T) { +{{- else}}func bitReverseNaive(v []{{ .FF }}.Element) { {{- end}} n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/koalabear/fft/bitreverse.go b/field/koalabear/fft/bitreverse.go index e86aa9924..7a57ac0ca 100644 --- a/field/koalabear/fft/bitreverse.go +++ b/field/koalabear/fft/bitreverse.go @@ -16,37 +16,9 @@ type SmallField interface { koalabear.Element | fext.E4 } -// BitReverseGeneric applies the bit-reversal permutation to the elements of slice v. -// It is a generic function that works on slices of any type T. -// Type T can be for example: koalabear.Element or fext.E4. -func BitReverseGeneric[T SmallField](v []T) { - n := uint64(len(v)) - // Check if the length is a power of 2 using bit manipulation - if bits.OnesCount64(n) != 1 { - panic("len(v) must be a power of 2") - } - - // This is the naive bit-reversal algorithm (in-place swap) - // nn is used to calculate the significant bits for reversal - nn := uint64(64 - bits.TrailingZeros64(n)) - - for i := uint64(0); i < n; i++ { - // Calculate the bit-reversed index - // bits.Reverse64 reverses the 64-bit representation of i - // We then right-shift by nn to get the reversal within the range [0, n-1] - iRev := bits.Reverse64(i) >> nn - - // Swap elements only if the reversed index is greater than the current index - // This prevents swapping elements twice (i -> iRev, then later iRev -> i) - if iRev > i { - v[i], v[iRev] = v[iRev], v[i] - } - } -} - // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse(v []koalabear.Element) { +func BitReverse[T SmallField](v []T) { n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -57,7 +29,7 @@ func BitReverse(v []koalabear.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive(v []koalabear.Element) { +func bitReverseNaive[T SmallField](v []T) { n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) From 1f4664262b1a93e73d02a42124d2a8c625c1efc4 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 14:13:56 +0200 Subject: [PATCH 11/15] refactor: fftext_test tmpl --- .../internal/templates/fft/tests/fftext.go.tmpl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/field/generator/internal/templates/fft/tests/fftext.go.tmpl b/field/generator/internal/templates/fft/tests/fftext.go.tmpl index f6d707442..c46becb62 100644 --- a/field/generator/internal/templates/fft/tests/fftext.go.tmpl +++ b/field/generator/internal/templates/fft/tests/fftext.go.tmpl @@ -50,7 +50,7 @@ func TestFFTExt(t *testing.T) { copy(backupPol, pol) domain.FFTExt(pol, DIF) - BitReverseGeneric(pol) + BitReverse(pol) sample := domain.Generator sample.Exp(sample, big.NewInt(int64(ithpower))) @@ -77,7 +77,7 @@ func TestFFTExt(t *testing.T) { copy(backupPol, pol) domain.FFTExt(pol, DIF, OnCoset()) - BitReverseGeneric(pol) + BitReverse(pol) sample := domain.Generator sample.Exp(sample, big.NewInt(int64(ithpower))). @@ -104,7 +104,7 @@ func TestFFTExt(t *testing.T) { } copy(backupPol, pol) - BitReverseGeneric(pol) + BitReverse(pol) domain.FFTExt(pol, DIT) sample := domain.Generator @@ -130,10 +130,10 @@ func TestFFTExt(t *testing.T) { } copy(backupPol, pol) - BitReverseGeneric(pol) + BitReverse(pol) domain.FFTExt(pol, DIT) domain.FFTInverseExt(pol, DIF) - BitReverseGeneric(pol) + BitReverse(pol) check := true for i := 0; i < len(pol); i++ { @@ -160,10 +160,10 @@ func TestFFTExt(t *testing.T) { for i := 1; i <= nbCosets; i++ { - BitReverseGeneric(pol) + BitReverse(pol) domain.FFTExt(pol, DIT, OnCoset()) domain.FFTInverseExt(pol, DIF, OnCoset()) - BitReverseGeneric(pol) + BitReverse(pol) for i := 0; i < len(pol); i++ { check = check && pol[i].Equal(&backupPol[i]) From 0c49320a033f584ddce5fdc4e001f82aa3581895 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 14:45:15 +0200 Subject: [PATCH 12/15] refactor: bitreverse_test tmpl --- .../templates/fft/tests/bitreverse.go.tmpl | 177 +++++++++++++++++- field/koalabear/fft/bitreverse_test.go | 121 ++++++++++-- 2 files changed, 278 insertions(+), 20 deletions(-) diff --git a/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl b/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl index dd67bcada..5a47f1b35 100644 --- a/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl @@ -3,9 +3,12 @@ import ( "testing" "{{ .FieldPackagePath }}" + {{- if .F31}} + fext "{{ .FieldPackagePath }}/extensions" + {{- end}} ) - +{{- if not .F31}} type bitReverseVariant struct { name string buf []{{ .FF }}.Element @@ -98,3 +101,175 @@ func BenchmarkBitReverse(b *testing.B) { } } } +{{- else}} +type bitReverseVariant[T SmallField] struct { + name string + buf []T + fn func([]T) +} + +const maxSizeBitReverse = 1 << 23 +var {{ .FF }}BitReverse = []bitReverseVariant[{{ .FF }}.Element]{ + {name: "bitReverseNaive", buf: make([]{{ .FF }}.Element, maxSizeBitReverse), fn: bitReverseNaive[{{ .FF }}.Element]}, + {name: "BitReverse", buf: make([]{{ .FF }}.Element, maxSizeBitReverse), fn: BitReverse[{{ .FF }}.Element]}, +} + + +func TestElementBitReverse(t *testing.T) { + + // generate a random []{{ .FF }}.Element array of size 2**20 + pol := make([]{{ .FF }}.Element, maxSizeBitReverse) + one := {{ .FF }}.One() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range {{ .FF }}BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range {{ .FF }}BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len({{ .FF }}BitReverse); j++ { + if !{{ .FF }}BitReverse[0].buf[i].Equal(&{{ .FF }}BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", {{ .FF }}BitReverse[0].name, {{ .FF }}BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range {{ .FF }}BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len({{ .FF }}BitReverse); j++ { + if !{{ .FF }}BitReverse[0].buf[i].Equal(&{{ .FF }}BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", {{ .FF }}BitReverse[0].name, {{ .FF }}BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkElementBitReverse(b *testing.B) { + // generate a random []{{ .FF }}.Element array of size 2**22 + pol := make([]{{ .FF }}.Element, maxSizeBitReverse) + one := {{ .FF }}.One() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range {{ .FF }}BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range {{ .FF }}BitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + + +func TestE4BitReverse(t *testing.T) { + + // generate a random []{{ .FF }}.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []E4 array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range e4BitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} +{{- end}} diff --git a/field/koalabear/fft/bitreverse_test.go b/field/koalabear/fft/bitreverse_test.go index d1dd3cfcf..31f2ad018 100644 --- a/field/koalabear/fft/bitreverse_test.go +++ b/field/koalabear/fft/bitreverse_test.go @@ -10,22 +10,23 @@ import ( "testing" "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) -type bitReverseVariant struct { +type bitReverseVariant[T SmallField] struct { name string - buf []koalabear.Element - fn func([]koalabear.Element) + buf []T + fn func([]T) } const maxSizeBitReverse = 1 << 23 -var bitReverse = []bitReverseVariant{ - {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive}, - {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse}, +var koalabearBitReverse = []bitReverseVariant[koalabear.Element]{ + {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive[koalabear.Element]}, + {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse[koalabear.Element]}, } -func TestBitReverse(t *testing.T) { +func TestElementBitReverse(t *testing.T) { // generate a random []koalabear.Element array of size 2**20 pol := make([]koalabear.Element, maxSizeBitReverse) @@ -39,33 +40,33 @@ func TestBitReverse(t *testing.T) { for size := 2; size <= maxSizeBitReverse; size <<= 1 { // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { copy(data.buf, pol[:size]) } // compute bit reverse shuffling - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { data.fn(data.buf[:size]) } // all bitReverse.buf should hold the same result for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(koalabearBitReverse); j++ { + if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) } } } // bitReverse back should be identity - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { data.fn(data.buf[:size]) } for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(koalabearBitReverse); j++ { + if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) } } } @@ -73,7 +74,7 @@ func TestBitReverse(t *testing.T) { } -func BenchmarkBitReverse(b *testing.B) { +func BenchmarkElementBitReverse(b *testing.B) { // generate a random []koalabear.Element array of size 2**22 pol := make([]koalabear.Element, maxSizeBitReverse) one := koalabear.One() @@ -83,13 +84,95 @@ func BenchmarkBitReverse(b *testing.B) { } // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { copy(data.buf, pol[:maxSizeBitReverse]) } // benchmark for each size, each bitReverse function for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + +func TestE4BitReverse(t *testing.T) { + + // generate a random []koalabear.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []E4 array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range e4BitReverse { b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { From f6bc8f9118646377264a8a9d7596f225ece8d098 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Fri, 9 May 2025 20:11:23 +0200 Subject: [PATCH 13/15] fix: add fftext path in generator_fft.go --- field/babybear/fft/bitreverse.go | 4 +- field/babybear/fft/bitreverse_test.go | 121 +++++- field/babybear/fft/fftext.go | 407 ++++++++++++++++++ field/babybear/fft/fftext_test.go | 400 +++++++++++++++++ field/babybear/fft/kernel_purego.go | 16 + field/generator/generator_fft.go | 5 +- .../internal/templates/fft/bitreverse.go.tmpl | 3 +- .../fft/{fftext.go .tmpl => fftext.go.tmpl} | 92 +++- .../templates/fft/kernel.purego.go.tmpl | 4 +- .../templates/fft/tests/bitreverse.go.tmpl | 3 +- .../templates/fft/tests/fftext.go.tmpl | 2 +- field/koalabear/fft/fftext.go | 13 +- field/koalabear/fft/fftext_test.go | 27 +- 13 files changed, 1042 insertions(+), 55 deletions(-) create mode 100644 field/babybear/fft/fftext.go create mode 100644 field/babybear/fft/fftext_test.go rename field/generator/internal/templates/fft/{fftext.go .tmpl => fftext.go.tmpl} (81%) diff --git a/field/babybear/fft/bitreverse.go b/field/babybear/fft/bitreverse.go index f59436bf2..d604ba7ca 100644 --- a/field/babybear/fft/bitreverse.go +++ b/field/babybear/fft/bitreverse.go @@ -18,7 +18,7 @@ type SmallField interface { // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse(v []babybear.Element) { +func BitReverse[T SmallField](v []T) { n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -29,7 +29,7 @@ func BitReverse(v []babybear.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive(v []babybear.Element) { +func bitReverseNaive[T SmallField](v []T) { n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/babybear/fft/bitreverse_test.go b/field/babybear/fft/bitreverse_test.go index 96aa6f323..6d093091a 100644 --- a/field/babybear/fft/bitreverse_test.go +++ b/field/babybear/fft/bitreverse_test.go @@ -10,22 +10,23 @@ import ( "testing" "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) -type bitReverseVariant struct { +type bitReverseVariant[T SmallField] struct { name string - buf []babybear.Element - fn func([]babybear.Element) + buf []T + fn func([]T) } const maxSizeBitReverse = 1 << 23 -var bitReverse = []bitReverseVariant{ - {name: "bitReverseNaive", buf: make([]babybear.Element, maxSizeBitReverse), fn: bitReverseNaive}, - {name: "BitReverse", buf: make([]babybear.Element, maxSizeBitReverse), fn: BitReverse}, +var babybearBitReverse = []bitReverseVariant[babybear.Element]{ + {name: "bitReverseNaive", buf: make([]babybear.Element, maxSizeBitReverse), fn: bitReverseNaive[babybear.Element]}, + {name: "BitReverse", buf: make([]babybear.Element, maxSizeBitReverse), fn: BitReverse[babybear.Element]}, } -func TestBitReverse(t *testing.T) { +func TestElementBitReverse(t *testing.T) { // generate a random []babybear.Element array of size 2**20 pol := make([]babybear.Element, maxSizeBitReverse) @@ -39,33 +40,33 @@ func TestBitReverse(t *testing.T) { for size := 2; size <= maxSizeBitReverse; size <<= 1 { // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range babybearBitReverse { copy(data.buf, pol[:size]) } // compute bit reverse shuffling - for _, data := range bitReverse { + for _, data := range babybearBitReverse { data.fn(data.buf[:size]) } // all bitReverse.buf should hold the same result for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(babybearBitReverse); j++ { + if !babybearBitReverse[0].buf[i].Equal(&babybearBitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", babybearBitReverse[0].name, babybearBitReverse[j].name) } } } // bitReverse back should be identity - for _, data := range bitReverse { + for _, data := range babybearBitReverse { data.fn(data.buf[:size]) } for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(babybearBitReverse); j++ { + if !babybearBitReverse[0].buf[i].Equal(&babybearBitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", babybearBitReverse[0].name, babybearBitReverse[j].name) } } } @@ -73,7 +74,7 @@ func TestBitReverse(t *testing.T) { } -func BenchmarkBitReverse(b *testing.B) { +func BenchmarkElementBitReverse(b *testing.B) { // generate a random []babybear.Element array of size 2**22 pol := make([]babybear.Element, maxSizeBitReverse) one := babybear.One() @@ -83,13 +84,95 @@ func BenchmarkBitReverse(b *testing.B) { } // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range babybearBitReverse { copy(data.buf, pol[:maxSizeBitReverse]) } // benchmark for each size, each bitReverse function for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range bitReverse { + for _, data := range babybearBitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + +func TestE4BitReverse(t *testing.T) { + + // generate a random []babybear.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []E4 array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range e4BitReverse { b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { diff --git a/field/babybear/fft/fftext.go b/field/babybear/fft/fftext.go new file mode 100644 index 000000000..a0cf078c1 --- /dev/null +++ b/field/babybear/fft/fftext.go @@ -0,0 +1,407 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/big" + "math/bits" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/internal/parallel" + + "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" +) + +// FFTExt computes the discrete Fourier transform of a slice of extension field elements. +// Coefficients and evaluations are extension field elements. +// The root of unity domain is the same as FFT. +func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { + + opt := fftOptions(opts) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } + + // if coset != 0, scale by coset table + if opt.coset { + + if decimation == DIT { + + // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]babybear.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].MulByElement(&a[i], &cosetTable[irev]) + } + }, opt.nbTasks) + } else { + + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at babybear.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + + } + } + + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + if nbStages-twiddlesStartStage > 0 { + twiddles = make([][]babybear.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]babybear.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1<> nn) + a[i].MulByElement(&a[i], &cosetTableInv[irev]). + MulByElement(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + +} + +func difFFTExt(a []fext.E4, w babybear.Element, twiddles [][]babybear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDIFNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + m := n >> 1 + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at babybear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddlesExt(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) + } else { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + +} + +func innerDIFWithTwiddlesGenericExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddlesExt(a []fext.E4, at, w babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFTExt(a []fext.E4, w babybear.Element, twiddles [][]babybear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDITNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + + m := n >> 1 + + nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) + + if stage < maxSplits { + // that's the only time we fire go routines + chDone := make(chan struct{}, 1) + + go ditFFTExt(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFTExt(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at babybear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + + } else { + innerDITWithoutTwiddlesExt(a, w, w, 0, m, m) + } + return + } + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) +} + +func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + fext.Butterfly(&a[i], &a[i+m]) + } +} + +func innerDITWithoutTwiddlesExt(a []fext.E4, at, w babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &at) + fext.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } +} + +func kerDIFNP_256genericExt(a []fext.E4, twiddles [][]babybear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + innerDIFWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256genericExt(a []fext.E4, twiddles [][]babybear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) +} diff --git a/field/babybear/fft/fftext_test.go b/field/babybear/fft/fftext_test.go new file mode 100644 index 000000000..51a70e5f5 --- /dev/null +++ b/field/babybear/fft/fftext_test.go @@ -0,0 +1,400 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/big" + "strconv" + "testing" + + "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "encoding/binary" + "fmt" + "github.com/stretchr/testify/require" + "math/rand/v2" +) + +func TestFFTExt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 6 + properties := gopter.NewProperties(parameters) + + for maxSize := 2; maxSize <= 1<<10; maxSize <<= 1 { + + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) + + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF, OnCoset()) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + domain.FFTInverseExt(pol, DIF) + BitReverse(pol) + + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) + + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + check := true + + for i := 1; i <= nbCosets; i++ { + + BitReverse(pol) + domain.FFTExt(pol, DIT, OnCoset()) + domain.FFTInverseExt(pol, DIF, OnCoset()) + BitReverse(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } + + return check + }, + )) + } + + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF) + domain.FFTExt(pol, DIT) + + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF, OnCoset()) + domain.FFTExt(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverseExt(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFTExt(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + return true + }, + )) + } + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } + +} + +func randElementExt(rng *rand.Rand) fext.E4 { + var v fext.E4 + v.B0.A0 = babybear.Element{rng.Uint32N(2013265921)} + v.B0.A1 = babybear.Element{rng.Uint32N(2013265921)} + v.B1.A0 = babybear.Element{rng.Uint32N(2013265921)} + v.B1.A1 = babybear.Element{rng.Uint32N(2013265921)} + return v +} + +func FuzzFFTExt(f *testing.F) { + f.Fuzz(func(t *testing.T, domainSize uint16, rngSeed int64) { + if domainSize > (1 << 13) { + t.Skip("domain size too large") + } + if domainSize < 2 { + t.Skip("domain size too small") + } + + domain := NewDomain(uint64(domainSize)) + + var seed [32]byte + binary.PutVarint(seed[:], rngSeed) + // #nosec G404 -- fuzz does not require a cryptographic PRNG + rng := rand.New(rand.NewChaCha8(seed)) + + cardinality := domain.Cardinality + + // we just check that FFT-1(FFT(pol)) == pol + a, b := make([]fext.E4, cardinality), make([]fext.E4, cardinality) + for i := 0; i < int(cardinality); i++ { + a[i] = randElementExt(rng) + } + copy(b, a) + + domain.FFTInverseExt(a, DIF) + domain.FFTExt(a, DIT) + + assert := require.New(t) + for i := 0; i < int(cardinality); i++ { + assert.True(a[i].Equal(&b[i]), "FFT-1(FFT(pol)) != pol at index %d", i) + } + }) +} + +// -------------------------------------------------------------------- +// benches + +func BenchmarkFFTExt(b *testing.B) { + + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + for i := 8; i < 20; i++ { + sizeDomain := 1 << i + b.Run("fft 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT) + } + }) + b.Run("fft 2**"+strconv.Itoa(i)+"bits (coset)", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT, OnCoset()) + } + }) + } + +} + +func BenchmarkFFTDITCosetReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIT, OnCoset()) + } +} + +func BenchmarkFFTDITReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < 1; j++ { + domain.FFTExt(pol, DIT) + } +} + +func BenchmarkFFTDIFReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} +func BenchmarkFFTDIFReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} + +func evaluatePolynomialExt(pol []fext.E4, val babybear.Element) fext.E4 { + var res, tmp fext.E4 + var acc babybear.Element + res.Set(&pol[0]) + acc.Set(&val) + for i := 1; i < len(pol); i++ { + tmp.MulByElement(&pol[i], &acc) + res.Add(&res, &tmp) + acc.Mul(&acc, &val) + } + return res +} diff --git a/field/babybear/fft/kernel_purego.go b/field/babybear/fft/kernel_purego.go index 2d04cd8b9..a8e2167d8 100644 --- a/field/babybear/fft/kernel_purego.go +++ b/field/babybear/fft/kernel_purego.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) func innerDIFWithTwiddles(a []babybear.Element, twiddles []babybear.Element, start, end, m int) { @@ -25,3 +26,18 @@ func kerDIFNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int func kerDITNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int) { kerDITNP_256generic(a, twiddles, stage) } + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + kerDIFNP_256genericExt(a, twiddles, stage) +} +func kerDITNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + kerDITNP_256genericExt(a, twiddles, stage) +} diff --git a/field/generator/generator_fft.go b/field/generator/generator_fft.go index 386ca43d8..a924f79fb 100644 --- a/field/generator/generator_fft.go +++ b/field/generator/generator_fft.go @@ -58,7 +58,10 @@ func generateFFT(F *config.Field, fft *config.FFT, outputDir string) error { {File: filepath.Join(outputDir, "bitreverse.go"), Templates: []string{"bitreverse.go.tmpl"}}, {File: filepath.Join(outputDir, "options.go"), Templates: []string{"options.go.tmpl"}}, } - + if F.F31 { + entries = append(entries, bavard.Entry{File: filepath.Join(outputDir, "fftext_test.go"), Templates: []string{"tests/fftext.go.tmpl"}}) + entries = append(entries, bavard.Entry{File: filepath.Join(outputDir, "fftext.go"), Templates: []string{"fftext.go.tmpl"}}) + } if data.HasASMKernel { data.Q = F.Q[0] data.QInvNeg = F.QInverse[0] diff --git a/field/generator/internal/templates/fft/bitreverse.go.tmpl b/field/generator/internal/templates/fft/bitreverse.go.tmpl index 47f9ea75a..d917c2faf 100644 --- a/field/generator/internal/templates/fft/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/bitreverse.go.tmpl @@ -257,5 +257,4 @@ func bitReverseCobraInPlace_{{.logTileSize}}_{{.logN}}(v []{{ .FF }}.Element) { } -{{- end}} - +{{- end}} \ No newline at end of file diff --git a/field/generator/internal/templates/fft/fftext.go .tmpl b/field/generator/internal/templates/fft/fftext.go.tmpl similarity index 81% rename from field/generator/internal/templates/fft/fftext.go .tmpl rename to field/generator/internal/templates/fft/fftext.go.tmpl index 2c8fe52d3..52cb56c24 100644 --- a/field/generator/internal/templates/fft/fftext.go .tmpl +++ b/field/generator/internal/templates/fft/fftext.go.tmpl @@ -4,7 +4,7 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/internal/parallel" - + "{{ .FieldPackagePath }}" fext "{{ .FieldPackagePath }}/extensions" ) @@ -359,3 +359,93 @@ func ditFFTExt(a []fext.E4, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, } {{- end}} } + + +func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + fext.Butterfly(&a[i], &a[i+m]) + } +} + +func innerDITWithoutTwiddlesExt(a []fext.E4, at, w {{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &at) + fext.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } +} + + +{{range $ki, $klog2 := $.Kernels}} + {{$ksize := shl 1 $klog2}} + {{genKernel $.FF $ksize $klog2}} +{{end}} + +{{define "genKernel FF sizeKernel sizeKernelLog2"}} + +func kerDIFNP_{{.sizeKernel}}genericExt(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + {{ $n := shl 1 .sizeKernelLog2}} + {{ $m := div $n 2}} + {{ $split := 1}} + {{- range $step := iterate 0 .sizeKernelLog2}} + {{- $offset := 0}} + + {{- $bound := mul $split $n}} + {{- if eq $bound $n}} + innerDIFWithTwiddlesGenericExt(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- else}} + for offset := 0; offset < {{$bound}}; offset += {{$n}} { + {{- if eq $m 1}} + fext.Butterfly(&a[offset], &a[offset+1]) + {{- else}} + innerDIFWithTwiddlesGenericExt(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- end}} + } + {{- end}} + + {{- $n = div $n 2}} + {{- $m = div $n 2}} + {{- $split = mul $split 2}} + {{- end}} +} + +func kerDITNP_{{.sizeKernel}}genericExt(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + {{ $n := 2}} + {{ $m := div $n 2}} + {{ $split := div (shl 1 .sizeKernelLog2) 2}} + {{- range $step := reverse (iterate 0 .sizeKernelLog2)}} + {{- $offset := 0}} + + {{- $bound := mul $split $n}} + {{- if eq $bound $n}} + innerDITWithTwiddlesGenericExt(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- else}} + for offset := 0; offset < {{$bound}}; offset += {{$n}} { + {{- if eq $m 1}} + fext.Butterfly(&a[offset], &a[offset+1]) + {{- else}} + innerDITWithTwiddlesGenericExt(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- end}} + } + {{- end}} + + {{- $n = mul $n 2}} + {{- $m = div $n 2}} + {{- $split = div $split 2}} + {{- end}} +} + +{{end}} + diff --git a/field/generator/internal/templates/fft/kernel.purego.go.tmpl b/field/generator/internal/templates/fft/kernel.purego.go.tmpl index 35cce050c..24eff28b9 100644 --- a/field/generator/internal/templates/fft/kernel.purego.go.tmpl +++ b/field/generator/internal/templates/fft/kernel.purego.go.tmpl @@ -33,10 +33,10 @@ func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, e } {{range $ki, $klog2 := $.Kernels}} {{- $ksize := shl 1 $klog2}} -func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { +func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { kerDIFNP_{{$ksize}}genericExt(a, twiddles, stage) } -func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { +func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { kerDITNP_{{$ksize}}genericExt(a, twiddles, stage) } {{end}} diff --git a/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl b/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl index 5a47f1b35..f42188990 100644 --- a/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl @@ -101,6 +101,7 @@ func BenchmarkBitReverse(b *testing.B) { } } } + {{- else}} type bitReverseVariant[T SmallField] struct { name string @@ -272,4 +273,4 @@ func BenchmarkE4BitReverse(b *testing.B) { } } } -{{- end}} +{{- end}} \ No newline at end of file diff --git a/field/generator/internal/templates/fft/tests/fftext.go.tmpl b/field/generator/internal/templates/fft/tests/fftext.go.tmpl index c46becb62..e57c4f147 100644 --- a/field/generator/internal/templates/fft/tests/fftext.go.tmpl +++ b/field/generator/internal/templates/fft/tests/fftext.go.tmpl @@ -396,4 +396,4 @@ func evaluatePolynomialExt(pol []fext.E4, val {{ .FF }}.Element) fext.E4 { acc.Mul(&acc, &val) } return res -} +} \ No newline at end of file diff --git a/field/koalabear/fft/fftext.go b/field/koalabear/fft/fftext.go index eb17dbeb8..5d57b9f75 100644 --- a/field/koalabear/fft/fftext.go +++ b/field/koalabear/fft/fftext.go @@ -1,3 +1,8 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + package fft import ( @@ -321,7 +326,6 @@ func ditFFTExt(a []fext.E4, w koalabear.Element, twiddles [][]koalabear.Element, } return } - innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) } @@ -330,13 +334,10 @@ func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []koalabear.Element, s fext.Butterfly(&a[0], &a[m]) start++ } - for i := start; i < end; i++ { a[i+m].MulByElement(&a[i+m], &twiddles[i]) fext.Butterfly(&a[i], &a[i+m]) - } - } func innerDITWithoutTwiddlesExt(a []fext.E4, at, w koalabear.Element, start, end, m int) { @@ -352,7 +353,7 @@ func innerDITWithoutTwiddlesExt(a []fext.E4, at, w koalabear.Element, start, end } func kerDIFNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage int) { - // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl innerDIFWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) for offset := 0; offset < 256; offset += 128 { @@ -379,7 +380,7 @@ func kerDIFNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage i } func kerDITNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage int) { - // code unrolled & generated by internal/generator/fft/template/fft.go.tmpl + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl for offset := 0; offset < 256; offset += 2 { fext.Butterfly(&a[offset], &a[offset+1]) diff --git a/field/koalabear/fft/fftext_test.go b/field/koalabear/fft/fftext_test.go index 21343e4b5..3bada0049 100644 --- a/field/koalabear/fft/fftext_test.go +++ b/field/koalabear/fft/fftext_test.go @@ -1,16 +1,5 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. // Code generated by consensys/gnark-crypto DO NOT EDIT @@ -21,19 +10,17 @@ import ( "strconv" "testing" - "fmt" + "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" + "github.com/leanovate/gopter" "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter/prop" "encoding/binary" - - "math/rand/v2" - + "fmt" "github.com/stretchr/testify/require" - - "github.com/consensys/gnark-crypto/field/koalabear" - fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" + "math/rand/v2" ) func TestFFTExt(t *testing.T) { From a9eabc5d24691544e9f70cf5b2288c6c40e782e5 Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Tue, 13 May 2025 14:23:03 +0200 Subject: [PATCH 14/15] feat: define missing fuctions relate to avx512 --- field/babybear/fft/kernel_amd64.go | 65 +++++++++++++++++ .../templates/fft/kernel.amd64.go.tmpl | 71 +++++++++++++++++++ field/koalabear/fft/kernel_amd64.go | 65 +++++++++++++++++ 3 files changed, 201 insertions(+) diff --git a/field/babybear/fft/kernel_amd64.go b/field/babybear/fft/kernel_amd64.go index c555efe1d..0d91e5d70 100644 --- a/field/babybear/fft/kernel_amd64.go +++ b/field/babybear/fft/kernel_amd64.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" "github.com/consensys/gnark-crypto/utils/cpu" ) @@ -65,3 +66,67 @@ func kerDITNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int } kerDITNP_256_avx512(a, twiddles, stage) } + +func ConvertE4SliceToCoefficientSlices(input []fext.E4) [4][]babybear.Element { + n := len(input) + + // Create the four output slices, each with the same length as the input slice + outputC0 := make([]babybear.Element, n) + outputC1 := make([]babybear.Element, n) + outputC2 := make([]babybear.Element, n) + outputC3 := make([]babybear.Element, n) + + // Iterate through the input slice and distribute the coefficients + for i := 0; i < n; i++ { + e4Element := input[i] // Get the current fext.E4 element + + // Extract coefficients and place them into the corresponding output slices + outputC0[i] = e4Element.B0.A0 + outputC1[i] = e4Element.B0.A1 + outputC2[i] = e4Element.B1.A0 + outputC3[i] = e4Element.B1.A1 + } + + // Return the four slices packaged in an array + return [4][]babybear.Element{outputC0, outputC1, outputC2, outputC3} +} + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + innerDIFWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) + } +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + innerDITWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) + } +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDIFNP_256genericExt(a, twiddles, stage) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + kerDIFNP_256_avx512(v, twiddles, stage) + } +} + +func kerDITNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDITNP_256genericExt(a, twiddles, stage) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + kerDITNP_256_avx512(v, twiddles, stage) + } +} diff --git a/field/generator/internal/templates/fft/kernel.amd64.go.tmpl b/field/generator/internal/templates/fft/kernel.amd64.go.tmpl index 722961203..3cdce8dbf 100644 --- a/field/generator/internal/templates/fft/kernel.amd64.go.tmpl +++ b/field/generator/internal/templates/fft/kernel.amd64.go.tmpl @@ -1,6 +1,8 @@ import ( "github.com/consensys/gnark-crypto/utils/cpu" "{{ .FieldPackagePath }}" + fext "{{ .FieldPackagePath }}/extensions" + ) @@ -63,3 +65,72 @@ func kerDITNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element } {{end}} + +func ConvertE4SliceToCoefficientSlices(input []fext.E4) [4][]{{ .FF }}.Element { + n := len(input) + + // Create the four output slices, each with the same length as the input slice + outputC0 := make([]{{ .FF }}.Element, n) + outputC1 := make([]{{ .FF }}.Element, n) + outputC2 := make([]{{ .FF }}.Element, n) + outputC3 := make([]{{ .FF }}.Element, n) + + // Iterate through the input slice and distribute the coefficients + for i := 0; i < n; i++ { + e4Element := input[i] // Get the current fext.E4 element + + // Extract coefficients and place them into the corresponding output slices + outputC0[i] = e4Element.B0.A0 + outputC1[i] = e4Element.B0.A1 + outputC2[i] = e4Element.B1.A0 + outputC3[i] = e4Element.B1.A1 + } + + // Return the four slices packaged in an array + return [4][]{{ .FF }}.Element{outputC0, outputC1, outputC2, outputC3} +} + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + for _, v :=range ConvertE4SliceToCoefficientSlices(a) { + innerDIFWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) + } +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + for _, v :=range ConvertE4SliceToCoefficientSlices(a) { + innerDITWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) + } +} + + +{{range $ki, $klog2 := $.Kernels}} + {{- $ksize := shl 1 $klog2}} + +func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { + if !cpu.SupportAVX512 { + kerDIFNP_{{$ksize}}genericExt(a, twiddles, stage) + return + } + for _, v :=range ConvertE4SliceToCoefficientSlices(a) { + kerDIFNP_{{$ksize}}_avx512(v, twiddles, stage) + } +} + +func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { + if !cpu.SupportAVX512 { + kerDITNP_{{$ksize}}genericExt(a, twiddles, stage) + return + } + for _, v :=range ConvertE4SliceToCoefficientSlices(a) { + kerDITNP_{{$ksize}}_avx512(v, twiddles, stage) + } +} +{{end}} \ No newline at end of file diff --git a/field/koalabear/fft/kernel_amd64.go b/field/koalabear/fft/kernel_amd64.go index 7abce70f4..d8f6a14e7 100644 --- a/field/koalabear/fft/kernel_amd64.go +++ b/field/koalabear/fft/kernel_amd64.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" "github.com/consensys/gnark-crypto/utils/cpu" ) @@ -65,3 +66,67 @@ func kerDITNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage i } kerDITNP_256_avx512(a, twiddles, stage) } + +func ConvertE4SliceToCoefficientSlices(input []fext.E4) [4][]koalabear.Element { + n := len(input) + + // Create the four output slices, each with the same length as the input slice + outputC0 := make([]koalabear.Element, n) + outputC1 := make([]koalabear.Element, n) + outputC2 := make([]koalabear.Element, n) + outputC3 := make([]koalabear.Element, n) + + // Iterate through the input slice and distribute the coefficients + for i := 0; i < n; i++ { + e4Element := input[i] // Get the current fext.E4 element + + // Extract coefficients and place them into the corresponding output slices + outputC0[i] = e4Element.B0.A0 + outputC1[i] = e4Element.B0.A1 + outputC2[i] = e4Element.B1.A0 + outputC3[i] = e4Element.B1.A1 + } + + // Return the four slices packaged in an array + return [4][]koalabear.Element{outputC0, outputC1, outputC2, outputC3} +} + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + innerDIFWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) + } +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + innerDITWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) + } +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDIFNP_256genericExt(a, twiddles, stage) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + kerDIFNP_256_avx512(v, twiddles, stage) + } +} + +func kerDITNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDITNP_256genericExt(a, twiddles, stage) + return + } + for _, v := range ConvertE4SliceToCoefficientSlices(a) { + kerDITNP_256_avx512(v, twiddles, stage) + } +} From 642026c65b898366d91fecec015e028cd55809dd Mon Sep 17 00:00:00 2001 From: Yao Galteland Date: Wed, 14 May 2025 21:26:07 +0200 Subject: [PATCH 15/15] docs: mark the todo for using avx512 --- field/babybear/fft/kernel_amd64.go | 40 ++---------------- .../templates/fft/kernel.amd64.go.tmpl | 41 ++----------------- field/koalabear/fft/kernel_amd64.go | 40 ++---------------- 3 files changed, 12 insertions(+), 109 deletions(-) diff --git a/field/babybear/fft/kernel_amd64.go b/field/babybear/fft/kernel_amd64.go index 0d91e5d70..f74883271 100644 --- a/field/babybear/fft/kernel_amd64.go +++ b/field/babybear/fft/kernel_amd64.go @@ -67,38 +67,12 @@ func kerDITNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int kerDITNP_256_avx512(a, twiddles, stage) } -func ConvertE4SliceToCoefficientSlices(input []fext.E4) [4][]babybear.Element { - n := len(input) - - // Create the four output slices, each with the same length as the input slice - outputC0 := make([]babybear.Element, n) - outputC1 := make([]babybear.Element, n) - outputC2 := make([]babybear.Element, n) - outputC3 := make([]babybear.Element, n) - - // Iterate through the input slice and distribute the coefficients - for i := 0; i < n; i++ { - e4Element := input[i] // Get the current fext.E4 element - - // Extract coefficients and place them into the corresponding output slices - outputC0[i] = e4Element.B0.A0 - outputC1[i] = e4Element.B0.A1 - outputC2[i] = e4Element.B1.A0 - outputC3[i] = e4Element.B1.A1 - } - - // Return the four slices packaged in an array - return [4][]babybear.Element{outputC0, outputC1, outputC2, outputC3} -} - func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { if !cpu.SupportAVX512 || m < 16 { innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - innerDIFWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) - } + //todo: use AVX512 } func innerDITWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { @@ -106,9 +80,7 @@ func innerDITWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, en innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - innerDITWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) - } + //todo: use AVX512 } func kerDIFNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { @@ -116,9 +88,7 @@ func kerDIFNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { kerDIFNP_256genericExt(a, twiddles, stage) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - kerDIFNP_256_avx512(v, twiddles, stage) - } + //todo: use AVX512 } func kerDITNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { @@ -126,7 +96,5 @@ func kerDITNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { kerDITNP_256genericExt(a, twiddles, stage) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - kerDITNP_256_avx512(v, twiddles, stage) - } + //todo: use AVX512 } diff --git a/field/generator/internal/templates/fft/kernel.amd64.go.tmpl b/field/generator/internal/templates/fft/kernel.amd64.go.tmpl index 3cdce8dbf..cbb1bbfd4 100644 --- a/field/generator/internal/templates/fft/kernel.amd64.go.tmpl +++ b/field/generator/internal/templates/fft/kernel.amd64.go.tmpl @@ -65,39 +65,12 @@ func kerDITNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element } {{end}} - -func ConvertE4SliceToCoefficientSlices(input []fext.E4) [4][]{{ .FF }}.Element { - n := len(input) - - // Create the four output slices, each with the same length as the input slice - outputC0 := make([]{{ .FF }}.Element, n) - outputC1 := make([]{{ .FF }}.Element, n) - outputC2 := make([]{{ .FF }}.Element, n) - outputC3 := make([]{{ .FF }}.Element, n) - - // Iterate through the input slice and distribute the coefficients - for i := 0; i < n; i++ { - e4Element := input[i] // Get the current fext.E4 element - - // Extract coefficients and place them into the corresponding output slices - outputC0[i] = e4Element.B0.A0 - outputC1[i] = e4Element.B0.A1 - outputC2[i] = e4Element.B1.A0 - outputC3[i] = e4Element.B1.A1 - } - - // Return the four slices packaged in an array - return [4][]{{ .FF }}.Element{outputC0, outputC1, outputC2, outputC3} -} - func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { if !cpu.SupportAVX512 || m < 16 { innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) return } - for _, v :=range ConvertE4SliceToCoefficientSlices(a) { - innerDIFWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) - } + //todo: use AVX512 } func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { @@ -105,9 +78,7 @@ func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, e innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) return } - for _, v :=range ConvertE4SliceToCoefficientSlices(a) { - innerDITWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) - } + //todo: use AVX512 } @@ -119,9 +90,7 @@ func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage kerDIFNP_{{$ksize}}genericExt(a, twiddles, stage) return } - for _, v :=range ConvertE4SliceToCoefficientSlices(a) { - kerDIFNP_{{$ksize}}_avx512(v, twiddles, stage) - } + //todo: use AVX512 } func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { @@ -129,8 +98,6 @@ func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage kerDITNP_{{$ksize}}genericExt(a, twiddles, stage) return } - for _, v :=range ConvertE4SliceToCoefficientSlices(a) { - kerDITNP_{{$ksize}}_avx512(v, twiddles, stage) - } + //todo: use AVX512 } {{end}} \ No newline at end of file diff --git a/field/koalabear/fft/kernel_amd64.go b/field/koalabear/fft/kernel_amd64.go index d8f6a14e7..fb2410724 100644 --- a/field/koalabear/fft/kernel_amd64.go +++ b/field/koalabear/fft/kernel_amd64.go @@ -67,38 +67,12 @@ func kerDITNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage i kerDITNP_256_avx512(a, twiddles, stage) } -func ConvertE4SliceToCoefficientSlices(input []fext.E4) [4][]koalabear.Element { - n := len(input) - - // Create the four output slices, each with the same length as the input slice - outputC0 := make([]koalabear.Element, n) - outputC1 := make([]koalabear.Element, n) - outputC2 := make([]koalabear.Element, n) - outputC3 := make([]koalabear.Element, n) - - // Iterate through the input slice and distribute the coefficients - for i := 0; i < n; i++ { - e4Element := input[i] // Get the current fext.E4 element - - // Extract coefficients and place them into the corresponding output slices - outputC0[i] = e4Element.B0.A0 - outputC1[i] = e4Element.B0.A1 - outputC2[i] = e4Element.B1.A0 - outputC3[i] = e4Element.B1.A1 - } - - // Return the four slices packaged in an array - return [4][]koalabear.Element{outputC0, outputC1, outputC2, outputC3} -} - func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { if !cpu.SupportAVX512 || m < 16 { innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - innerDIFWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) - } + //todo: use AVX512 } func innerDITWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { @@ -106,9 +80,7 @@ func innerDITWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, e innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - innerDITWithTwiddles_avx512(&v[0], &twiddles[0], start, end, m) - } + //todo: use AVX512 } func kerDIFNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { @@ -116,9 +88,7 @@ func kerDIFNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { kerDIFNP_256genericExt(a, twiddles, stage) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - kerDIFNP_256_avx512(v, twiddles, stage) - } + //todo: use AVX512 } func kerDITNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { @@ -126,7 +96,5 @@ func kerDITNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { kerDITNP_256genericExt(a, twiddles, stage) return } - for _, v := range ConvertE4SliceToCoefficientSlices(a) { - kerDITNP_256_avx512(v, twiddles, stage) - } + //todo: use AVX512 }