diff --git a/field/babybear/extensions/e4.go b/field/babybear/extensions/e4.go index 957710e3e..30d4d2115 100644 --- a/field/babybear/extensions/e4.go +++ b/field/babybear/extensions/e4.go @@ -369,3 +369,12 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { mulAccE4_avx512(alpha, &scale[0], &res[0], uint64(N)) } + +// Butterfly computes the butterfly operation on two E4 elements +func Butterfly(a, b *E4) { + fr.Butterfly(&a.B0.A0, &b.B0.A0) + fr.Butterfly(&a.B0.A1, &b.B0.A1) + + fr.Butterfly(&a.B1.A0, &b.B1.A0) + fr.Butterfly(&a.B1.A1, &b.B1.A1) +} diff --git a/field/babybear/fft/bitreverse.go b/field/babybear/fft/bitreverse.go index 6c376c898..d604ba7ca 100644 --- a/field/babybear/fft/bitreverse.go +++ b/field/babybear/fft/bitreverse.go @@ -9,11 +9,16 @@ import ( "math/bits" "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) +type SmallField interface { + babybear.Element | fext.E4 +} + // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse(v []babybear.Element) { +func BitReverse[T SmallField](v []T) { n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -24,7 +29,7 @@ func BitReverse(v []babybear.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive(v []babybear.Element) { +func bitReverseNaive[T SmallField](v []T) { n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/babybear/fft/bitreverse_test.go b/field/babybear/fft/bitreverse_test.go index 96aa6f323..6d093091a 100644 --- a/field/babybear/fft/bitreverse_test.go +++ b/field/babybear/fft/bitreverse_test.go @@ -10,22 +10,23 @@ import ( "testing" "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) -type bitReverseVariant struct { +type bitReverseVariant[T SmallField] struct { name string - buf []babybear.Element - fn func([]babybear.Element) + buf []T + fn func([]T) } const maxSizeBitReverse = 1 << 23 -var bitReverse = []bitReverseVariant{ - {name: "bitReverseNaive", buf: make([]babybear.Element, maxSizeBitReverse), fn: bitReverseNaive}, - {name: "BitReverse", buf: make([]babybear.Element, maxSizeBitReverse), fn: BitReverse}, +var babybearBitReverse = []bitReverseVariant[babybear.Element]{ + {name: "bitReverseNaive", buf: make([]babybear.Element, maxSizeBitReverse), fn: bitReverseNaive[babybear.Element]}, + {name: "BitReverse", buf: make([]babybear.Element, maxSizeBitReverse), fn: BitReverse[babybear.Element]}, } -func TestBitReverse(t *testing.T) { +func TestElementBitReverse(t *testing.T) { // generate a random []babybear.Element array of size 2**20 pol := make([]babybear.Element, maxSizeBitReverse) @@ -39,33 +40,33 @@ func TestBitReverse(t *testing.T) { for size := 2; size <= maxSizeBitReverse; size <<= 1 { // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range babybearBitReverse { copy(data.buf, pol[:size]) } // compute bit reverse shuffling - for _, data := range bitReverse { + for _, data := range babybearBitReverse { data.fn(data.buf[:size]) } // all bitReverse.buf should hold the same result for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(babybearBitReverse); j++ { + if !babybearBitReverse[0].buf[i].Equal(&babybearBitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", babybearBitReverse[0].name, babybearBitReverse[j].name) } } } // bitReverse back should be identity - for _, data := range bitReverse { + for _, data := range babybearBitReverse { data.fn(data.buf[:size]) } for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(babybearBitReverse); j++ { + if !babybearBitReverse[0].buf[i].Equal(&babybearBitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", babybearBitReverse[0].name, babybearBitReverse[j].name) } } } @@ -73,7 +74,7 @@ func TestBitReverse(t *testing.T) { } -func BenchmarkBitReverse(b *testing.B) { +func BenchmarkElementBitReverse(b *testing.B) { // generate a random []babybear.Element array of size 2**22 pol := make([]babybear.Element, maxSizeBitReverse) one := babybear.One() @@ -83,13 +84,95 @@ func BenchmarkBitReverse(b *testing.B) { } // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range babybearBitReverse { copy(data.buf, pol[:maxSizeBitReverse]) } // benchmark for each size, each bitReverse function for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range bitReverse { + for _, data := range babybearBitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + +func TestE4BitReverse(t *testing.T) { + + // generate a random []babybear.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []E4 array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range e4BitReverse { b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { diff --git a/field/babybear/fft/fftext.go b/field/babybear/fft/fftext.go new file mode 100644 index 000000000..a0cf078c1 --- /dev/null +++ b/field/babybear/fft/fftext.go @@ -0,0 +1,407 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/big" + "math/bits" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/internal/parallel" + + "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" +) + +// FFTExt computes the discrete Fourier transform of a slice of extension field elements. +// Coefficients and evaluations are extension field elements. +// The root of unity domain is the same as FFT. +func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { + + opt := fftOptions(opts) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } + + // if coset != 0, scale by coset table + if opt.coset { + + if decimation == DIT { + + // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]babybear.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].MulByElement(&a[i], &cosetTable[irev]) + } + }, opt.nbTasks) + } else { + + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at babybear.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + + } + } + + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + if nbStages-twiddlesStartStage > 0 { + twiddles = make([][]babybear.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]babybear.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1<> nn) + a[i].MulByElement(&a[i], &cosetTableInv[irev]). + MulByElement(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + +} + +func difFFTExt(a []fext.E4, w babybear.Element, twiddles [][]babybear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDIFNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + m := n >> 1 + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at babybear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddlesExt(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) + } else { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + +} + +func innerDIFWithTwiddlesGenericExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddlesExt(a []fext.E4, at, w babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFTExt(a []fext.E4, w babybear.Element, twiddles [][]babybear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDITNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + + m := n >> 1 + + nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) + + if stage < maxSplits { + // that's the only time we fire go routines + chDone := make(chan struct{}, 1) + + go ditFFTExt(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFTExt(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at babybear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + + } else { + innerDITWithoutTwiddlesExt(a, w, w, 0, m, m) + } + return + } + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) +} + +func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + fext.Butterfly(&a[i], &a[i+m]) + } +} + +func innerDITWithoutTwiddlesExt(a []fext.E4, at, w babybear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &at) + fext.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } +} + +func kerDIFNP_256genericExt(a []fext.E4, twiddles [][]babybear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + innerDIFWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256genericExt(a []fext.E4, twiddles [][]babybear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) +} diff --git a/field/babybear/fft/fftext_test.go b/field/babybear/fft/fftext_test.go new file mode 100644 index 000000000..51a70e5f5 --- /dev/null +++ b/field/babybear/fft/fftext_test.go @@ -0,0 +1,400 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/big" + "strconv" + "testing" + + "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "encoding/binary" + "fmt" + "github.com/stretchr/testify/require" + "math/rand/v2" +) + +func TestFFTExt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 6 + properties := gopter.NewProperties(parameters) + + for maxSize := 2; maxSize <= 1<<10; maxSize <<= 1 { + + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) + + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF, OnCoset()) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + domain.FFTInverseExt(pol, DIF) + BitReverse(pol) + + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) + + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + check := true + + for i := 1; i <= nbCosets; i++ { + + BitReverse(pol) + domain.FFTExt(pol, DIT, OnCoset()) + domain.FFTInverseExt(pol, DIF, OnCoset()) + BitReverse(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } + + return check + }, + )) + } + + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF) + domain.FFTExt(pol, DIT) + + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF, OnCoset()) + domain.FFTExt(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverseExt(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFTExt(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + return true + }, + )) + } + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } + +} + +func randElementExt(rng *rand.Rand) fext.E4 { + var v fext.E4 + v.B0.A0 = babybear.Element{rng.Uint32N(2013265921)} + v.B0.A1 = babybear.Element{rng.Uint32N(2013265921)} + v.B1.A0 = babybear.Element{rng.Uint32N(2013265921)} + v.B1.A1 = babybear.Element{rng.Uint32N(2013265921)} + return v +} + +func FuzzFFTExt(f *testing.F) { + f.Fuzz(func(t *testing.T, domainSize uint16, rngSeed int64) { + if domainSize > (1 << 13) { + t.Skip("domain size too large") + } + if domainSize < 2 { + t.Skip("domain size too small") + } + + domain := NewDomain(uint64(domainSize)) + + var seed [32]byte + binary.PutVarint(seed[:], rngSeed) + // #nosec G404 -- fuzz does not require a cryptographic PRNG + rng := rand.New(rand.NewChaCha8(seed)) + + cardinality := domain.Cardinality + + // we just check that FFT-1(FFT(pol)) == pol + a, b := make([]fext.E4, cardinality), make([]fext.E4, cardinality) + for i := 0; i < int(cardinality); i++ { + a[i] = randElementExt(rng) + } + copy(b, a) + + domain.FFTInverseExt(a, DIF) + domain.FFTExt(a, DIT) + + assert := require.New(t) + for i := 0; i < int(cardinality); i++ { + assert.True(a[i].Equal(&b[i]), "FFT-1(FFT(pol)) != pol at index %d", i) + } + }) +} + +// -------------------------------------------------------------------- +// benches + +func BenchmarkFFTExt(b *testing.B) { + + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + for i := 8; i < 20; i++ { + sizeDomain := 1 << i + b.Run("fft 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT) + } + }) + b.Run("fft 2**"+strconv.Itoa(i)+"bits (coset)", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT, OnCoset()) + } + }) + } + +} + +func BenchmarkFFTDITCosetReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIT, OnCoset()) + } +} + +func BenchmarkFFTDITReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < 1; j++ { + domain.FFTExt(pol, DIT) + } +} + +func BenchmarkFFTDIFReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} +func BenchmarkFFTDIFReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} + +func evaluatePolynomialExt(pol []fext.E4, val babybear.Element) fext.E4 { + var res, tmp fext.E4 + var acc babybear.Element + res.Set(&pol[0]) + acc.Set(&val) + for i := 1; i < len(pol); i++ { + tmp.MulByElement(&pol[i], &acc) + res.Add(&res, &tmp) + acc.Mul(&acc, &val) + } + return res +} diff --git a/field/babybear/fft/kernel_amd64.go b/field/babybear/fft/kernel_amd64.go index c555efe1d..f74883271 100644 --- a/field/babybear/fft/kernel_amd64.go +++ b/field/babybear/fft/kernel_amd64.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" "github.com/consensys/gnark-crypto/utils/cpu" ) @@ -65,3 +66,35 @@ func kerDITNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int } kerDITNP_256_avx512(a, twiddles, stage) } + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + //todo: use AVX512 +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + //todo: use AVX512 +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDIFNP_256genericExt(a, twiddles, stage) + return + } + //todo: use AVX512 +} + +func kerDITNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDITNP_256genericExt(a, twiddles, stage) + return + } + //todo: use AVX512 +} diff --git a/field/babybear/fft/kernel_purego.go b/field/babybear/fft/kernel_purego.go index 2d04cd8b9..a8e2167d8 100644 --- a/field/babybear/fft/kernel_purego.go +++ b/field/babybear/fft/kernel_purego.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/babybear" + fext "github.com/consensys/gnark-crypto/field/babybear/extensions" ) func innerDIFWithTwiddles(a []babybear.Element, twiddles []babybear.Element, start, end, m int) { @@ -25,3 +26,18 @@ func kerDIFNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int func kerDITNP_256(a []babybear.Element, twiddles [][]babybear.Element, stage int) { kerDITNP_256generic(a, twiddles, stage) } + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []babybear.Element, start, end, m int) { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + kerDIFNP_256genericExt(a, twiddles, stage) +} +func kerDITNP_256Ext(a []fext.E4, twiddles [][]babybear.Element, stage int) { + kerDITNP_256genericExt(a, twiddles, stage) +} diff --git a/field/generator/generator_fft.go b/field/generator/generator_fft.go index 386ca43d8..a924f79fb 100644 --- a/field/generator/generator_fft.go +++ b/field/generator/generator_fft.go @@ -58,7 +58,10 @@ func generateFFT(F *config.Field, fft *config.FFT, outputDir string) error { {File: filepath.Join(outputDir, "bitreverse.go"), Templates: []string{"bitreverse.go.tmpl"}}, {File: filepath.Join(outputDir, "options.go"), Templates: []string{"options.go.tmpl"}}, } - + if F.F31 { + entries = append(entries, bavard.Entry{File: filepath.Join(outputDir, "fftext_test.go"), Templates: []string{"tests/fftext.go.tmpl"}}) + entries = append(entries, bavard.Entry{File: filepath.Join(outputDir, "fftext.go"), Templates: []string{"fftext.go.tmpl"}}) + } if data.HasASMKernel { data.Q = F.Q[0] data.QInvNeg = F.QInverse[0] diff --git a/field/generator/internal/templates/extensions/e4.go.tmpl b/field/generator/internal/templates/extensions/e4.go.tmpl index 50c56fee5..0671f9683 100644 --- a/field/generator/internal/templates/extensions/e4.go.tmpl +++ b/field/generator/internal/templates/extensions/e4.go.tmpl @@ -366,4 +366,13 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { mulAccE4_avx512(alpha, &scale[0], &res[0], uint64(N)) } -{{- end}} \ No newline at end of file +{{- end}} + +// Butterfly computes the butterfly operation on two E4 elements +func Butterfly(a, b *E4) { + fr.Butterfly(&a.B0.A0, &b.B0.A0) + fr.Butterfly(&a.B0.A1, &b.B0.A1) + + fr.Butterfly(&a.B1.A0, &b.B1.A0) + fr.Butterfly(&a.B1.A1, &b.B1.A1) +} diff --git a/field/generator/internal/templates/fft/bitreverse.go.tmpl b/field/generator/internal/templates/fft/bitreverse.go.tmpl index c99e3451f..d917c2faf 100644 --- a/field/generator/internal/templates/fft/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/bitreverse.go.tmpl @@ -6,11 +6,22 @@ import ( {{- end}} "{{ .FieldPackagePath }}" + {{- if .F31}} + fext "{{ .FieldPackagePath }}/extensions" + {{- end}} ) +{{- if .F31}} +type SmallField interface { + {{ .FF }}.Element | fext.E4 +} +{{- end}} + // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse(v []{{ .FF }}.Element) { +{{ if .F31}}func BitReverse[T SmallField](v []T) { +{{- else}}func BitReverse(v []{{ .FF }}.Element) { +{{- end}} n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -29,7 +40,9 @@ func BitReverse(v []{{ .FF }}.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive(v []{{ .FF }}.Element) { +{{ if .F31}}func bitReverseNaive[T SmallField](v []T) { +{{- else}}func bitReverseNaive(v []{{ .FF }}.Element) { +{{- end}} n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/generator/internal/templates/fft/fftext.go.tmpl b/field/generator/internal/templates/fft/fftext.go.tmpl new file mode 100644 index 000000000..52cb56c24 --- /dev/null +++ b/field/generator/internal/templates/fft/fftext.go.tmpl @@ -0,0 +1,451 @@ +import ( + "math/big" + "math/bits" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/internal/parallel" + + "{{ .FieldPackagePath }}" + fext "{{ .FieldPackagePath }}/extensions" +) + + +// FFTExt computes the discrete Fourier transform of a slice of extension field elements. +// Coefficients and evaluations are extension field elements. +// The root of unity domain is the same as FFT. +func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { + + opt := fftOptions(opts) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } + + // if coset != 0, scale by coset table + if opt.coset { + + if decimation == DIT { + + // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]{{ .FF }}.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].MulByElement(&a[i], &cosetTable[irev]) + } + }, opt.nbTasks) + } else { + + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at {{ .FF }}.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + + } + } + + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + if nbStages-twiddlesStartStage > 0 { + twiddles = make([][]{{ .FF }}.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]{{ .FF }}.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1<> nn) + a[i].MulByElement(&a[i], &cosetTableInv[irev]). + MulByElement(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + +} + + +func difFFTExt(a []fext.E4, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + {{- range $ki, $klog2 := $.Kernels}} + {{- if ne $ki 0}} else {{- end}} if n == 1 << {{$klog2}} { + {{- $ksize := shl 1 $klog2}} + kerDIFNP_{{$ksize}}Ext(a, twiddles, stage-twiddlesStartStage) + return + } + {{- end }} + } + m := n >> 1 + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at {{ .FF }}.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddlesExt(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) + } else { + {{- if .HasASMKernel}} + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + {{- else}} + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks / (1 << (stage))) + } else { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + {{- end}} + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + +} + +func innerDIFWithTwiddlesGenericExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddlesExt(a []fext.E4, at, w {{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFTExt(a []fext.E4, w {{ .FF }}.Element, twiddles [][]{{ .FF }}.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + {{- range $ki, $klog2 := $.Kernels}} + {{- if ne $ki 0}} else {{- end}} if n == 1 << {{$klog2}} { + {{- $ksize := shl 1 $klog2}} + kerDITNP_{{$ksize}}Ext(a, twiddles, stage-twiddlesStartStage) + return + } + {{- end }} + } + + m := n >> 1 + + nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) + + if stage < maxSplits { + // that's the only time we fire go routines + chDone := make(chan struct{}, 1) + + go ditFFTExt(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFTExt(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at {{ .FF }}.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + + } else { + innerDITWithoutTwiddlesExt(a, w, w, 0, m, m) + } + return + } + {{- if .HasASMKernel}} + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + {{- else}} + if parallelButterfly { + parallel.Execute(m, func(start, end int) { + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], start, end, m) + }, nbTasks / (1 << (stage))) + } else { + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + {{- end}} +} + + +func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + fext.Butterfly(&a[i], &a[i+m]) + } +} + +func innerDITWithoutTwiddlesExt(a []fext.E4, at, w {{ .FF }}.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &at) + fext.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } +} + + +{{range $ki, $klog2 := $.Kernels}} + {{$ksize := shl 1 $klog2}} + {{genKernel $.FF $ksize $klog2}} +{{end}} + +{{define "genKernel FF sizeKernel sizeKernelLog2"}} + +func kerDIFNP_{{.sizeKernel}}genericExt(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + {{ $n := shl 1 .sizeKernelLog2}} + {{ $m := div $n 2}} + {{ $split := 1}} + {{- range $step := iterate 0 .sizeKernelLog2}} + {{- $offset := 0}} + + {{- $bound := mul $split $n}} + {{- if eq $bound $n}} + innerDIFWithTwiddlesGenericExt(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- else}} + for offset := 0; offset < {{$bound}}; offset += {{$n}} { + {{- if eq $m 1}} + fext.Butterfly(&a[offset], &a[offset+1]) + {{- else}} + innerDIFWithTwiddlesGenericExt(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- end}} + } + {{- end}} + + {{- $n = div $n 2}} + {{- $m = div $n 2}} + {{- $split = mul $split 2}} + {{- end}} +} + +func kerDITNP_{{.sizeKernel}}genericExt(a []fext.E4, twiddles [][]{{ .FF }}.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + {{ $n := 2}} + {{ $m := div $n 2}} + {{ $split := div (shl 1 .sizeKernelLog2) 2}} + {{- range $step := reverse (iterate 0 .sizeKernelLog2)}} + {{- $offset := 0}} + + {{- $bound := mul $split $n}} + {{- if eq $bound $n}} + innerDITWithTwiddlesGenericExt(a[:{{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- else}} + for offset := 0; offset < {{$bound}}; offset += {{$n}} { + {{- if eq $m 1}} + fext.Butterfly(&a[offset], &a[offset+1]) + {{- else}} + innerDITWithTwiddlesGenericExt(a[offset:offset + {{$n}}], twiddles[stage + {{$step}}], 0, {{$m}}, {{$m}}) + {{- end}} + } + {{- end}} + + {{- $n = mul $n 2}} + {{- $m = div $n 2}} + {{- $split = div $split 2}} + {{- end}} +} + +{{end}} + diff --git a/field/generator/internal/templates/fft/kernel.amd64.go.tmpl b/field/generator/internal/templates/fft/kernel.amd64.go.tmpl index 722961203..cbb1bbfd4 100644 --- a/field/generator/internal/templates/fft/kernel.amd64.go.tmpl +++ b/field/generator/internal/templates/fft/kernel.amd64.go.tmpl @@ -1,6 +1,8 @@ import ( "github.com/consensys/gnark-crypto/utils/cpu" "{{ .FieldPackagePath }}" + fext "{{ .FieldPackagePath }}/extensions" + ) @@ -63,3 +65,39 @@ func kerDITNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element } {{end}} +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + //todo: use AVX512 +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + //todo: use AVX512 +} + + +{{range $ki, $klog2 := $.Kernels}} + {{- $ksize := shl 1 $klog2}} + +func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { + if !cpu.SupportAVX512 { + kerDIFNP_{{$ksize}}genericExt(a, twiddles, stage) + return + } + //todo: use AVX512 +} + +func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { + if !cpu.SupportAVX512 { + kerDITNP_{{$ksize}}genericExt(a, twiddles, stage) + return + } + //todo: use AVX512 +} +{{end}} \ No newline at end of file diff --git a/field/generator/internal/templates/fft/kernel.purego.go.tmpl b/field/generator/internal/templates/fft/kernel.purego.go.tmpl index 99d75c3d9..24eff28b9 100644 --- a/field/generator/internal/templates/fft/kernel.purego.go.tmpl +++ b/field/generator/internal/templates/fft/kernel.purego.go.tmpl @@ -1,5 +1,8 @@ import ( "{{ .FieldPackagePath }}" + {{- if .F31}} + fext "{{ .FieldPackagePath }}/extensions" + {{- end}} ) func innerDIFWithTwiddles(a []{{ .FF }}.Element, twiddles []{{ .FF }}.Element, start, end, m int) { @@ -18,4 +21,23 @@ func kerDIFNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element func kerDITNP_{{$ksize}}(a []{{ $.FF }}.Element, twiddles [][]{{ $.FF }}.Element, stage int) { kerDITNP_{{$ksize}}generic(a, twiddles, stage) } -{{end}} \ No newline at end of file +{{end}} + +{{- if .F31}} +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []{{ .FF }}.Element, start, end, m int) { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) +} +{{range $ki, $klog2 := $.Kernels}} + {{- $ksize := shl 1 $klog2}} +func kerDIFNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { + kerDIFNP_{{$ksize}}genericExt(a, twiddles, stage) +} +func kerDITNP_{{$ksize}}Ext(a []fext.E4, twiddles [][]{{ $.FF }}.Element, stage int) { + kerDITNP_{{$ksize}}genericExt(a, twiddles, stage) +} +{{end}} +{{- end}} \ No newline at end of file diff --git a/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl b/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl index dd67bcada..f42188990 100644 --- a/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl +++ b/field/generator/internal/templates/fft/tests/bitreverse.go.tmpl @@ -3,9 +3,12 @@ import ( "testing" "{{ .FieldPackagePath }}" + {{- if .F31}} + fext "{{ .FieldPackagePath }}/extensions" + {{- end}} ) - +{{- if not .F31}} type bitReverseVariant struct { name string buf []{{ .FF }}.Element @@ -98,3 +101,176 @@ func BenchmarkBitReverse(b *testing.B) { } } } + +{{- else}} +type bitReverseVariant[T SmallField] struct { + name string + buf []T + fn func([]T) +} + +const maxSizeBitReverse = 1 << 23 +var {{ .FF }}BitReverse = []bitReverseVariant[{{ .FF }}.Element]{ + {name: "bitReverseNaive", buf: make([]{{ .FF }}.Element, maxSizeBitReverse), fn: bitReverseNaive[{{ .FF }}.Element]}, + {name: "BitReverse", buf: make([]{{ .FF }}.Element, maxSizeBitReverse), fn: BitReverse[{{ .FF }}.Element]}, +} + + +func TestElementBitReverse(t *testing.T) { + + // generate a random []{{ .FF }}.Element array of size 2**20 + pol := make([]{{ .FF }}.Element, maxSizeBitReverse) + one := {{ .FF }}.One() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range {{ .FF }}BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range {{ .FF }}BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len({{ .FF }}BitReverse); j++ { + if !{{ .FF }}BitReverse[0].buf[i].Equal(&{{ .FF }}BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", {{ .FF }}BitReverse[0].name, {{ .FF }}BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range {{ .FF }}BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len({{ .FF }}BitReverse); j++ { + if !{{ .FF }}BitReverse[0].buf[i].Equal(&{{ .FF }}BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", {{ .FF }}BitReverse[0].name, {{ .FF }}BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkElementBitReverse(b *testing.B) { + // generate a random []{{ .FF }}.Element array of size 2**22 + pol := make([]{{ .FF }}.Element, maxSizeBitReverse) + one := {{ .FF }}.One() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range {{ .FF }}BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range {{ .FF }}BitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + + +func TestE4BitReverse(t *testing.T) { + + // generate a random []{{ .FF }}.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []E4 array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range e4BitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} +{{- end}} \ No newline at end of file diff --git a/field/generator/internal/templates/fft/tests/fftext.go.tmpl b/field/generator/internal/templates/fft/tests/fftext.go.tmpl new file mode 100644 index 000000000..e57c4f147 --- /dev/null +++ b/field/generator/internal/templates/fft/tests/fftext.go.tmpl @@ -0,0 +1,399 @@ +import ( + "math/big" + "testing" + "strconv" + + + "{{ .FieldPackagePath }}" + fext "{{ .FieldPackagePath }}/extensions" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" + "github.com/leanovate/gopter/gen" + + "fmt" + {{- if .F31}} + "github.com/stretchr/testify/require" + "encoding/binary" + "math/rand/v2" + {{- end}} +) + +func TestFFTExt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 6 + properties := gopter.NewProperties(parameters) + + for maxSize := 2; maxSize <= 1<<10; maxSize <<= 1 { + + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) + + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF, OnCoset()) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + domain.FFTInverseExt(pol, DIF) + BitReverse(pol) + + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) + + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + check := true + + for i := 1; i <= nbCosets; i++ { + + BitReverse(pol) + domain.FFTExt(pol, DIT, OnCoset()) + domain.FFTInverseExt(pol, DIF, OnCoset()) + BitReverse(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } + + return check + }, + )) + } + + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF) + domain.FFTExt(pol, DIT) + + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF, OnCoset()) + domain.FFTExt(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverseExt(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFTExt(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + return true + }, + )) + } + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } + +} +{{- if .F31}} + +func randElementExt(rng *rand.Rand) fext.E4 { + var v fext.E4 + v.B0.A0 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + v.B0.A1 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + v.B1.A0 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + v.B1.A1 = {{ .FF }}.Element{rng.Uint32N({{.Q}})} + return v +} + +func FuzzFFTExt(f *testing.F) { + f.Fuzz(func(t *testing.T, domainSize uint16, rngSeed int64) { + if domainSize > (1 << 13) { + t.Skip("domain size too large") + } + if domainSize < 2 { + t.Skip("domain size too small") + } + + domain := NewDomain(uint64(domainSize)) + + var seed [32]byte + binary.PutVarint(seed[:], rngSeed) + // #nosec G404 -- fuzz does not require a cryptographic PRNG + rng := rand.New(rand.NewChaCha8(seed)) + + cardinality := domain.Cardinality + + // we just check that FFT-1(FFT(pol)) == pol + a, b := make([]fext.E4, cardinality), make([]fext.E4, cardinality) + for i := 0; i < int(cardinality); i++ { + a[i] = randElementExt(rng) + } + copy(b, a) + + domain.FFTInverseExt(a, DIF) + domain.FFTExt(a, DIT) + + assert := require.New(t) + for i := 0; i < int(cardinality); i++ { + assert.True(a[i].Equal(&b[i]), "FFT-1(FFT(pol)) != pol at index %d", i) + } + }) +} + +{{- end}} + +// -------------------------------------------------------------------- +// benches + +func BenchmarkFFTExt(b *testing.B) { + + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + for i := 8; i < 20; i++ { + sizeDomain := 1 << i + b.Run("fft 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT) + } + }) + b.Run("fft 2**"+strconv.Itoa(i)+"bits (coset)", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT, OnCoset()) + } + }) + } + +} + +func BenchmarkFFTDITCosetReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIT, OnCoset()) + } +} + +func BenchmarkFFTDITReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < 1; j++ { + domain.FFTExt(pol, DIT) + } +} + +func BenchmarkFFTDIFReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} +func BenchmarkFFTDIFReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} + +func evaluatePolynomialExt(pol []fext.E4, val {{ .FF }}.Element) fext.E4 { + var res, tmp fext.E4 + var acc {{ .FF }}.Element + res.Set(&pol[0]) + acc.Set(&val) + for i := 1; i < len(pol); i++ { + tmp.MulByElement(&pol[i], &acc) + res.Add(&res, &tmp) + acc.Mul(&acc, &val) + } + return res +} \ No newline at end of file diff --git a/field/koalabear/extensions/e4.go b/field/koalabear/extensions/e4.go index af45fedc8..fcab0827b 100644 --- a/field/koalabear/extensions/e4.go +++ b/field/koalabear/extensions/e4.go @@ -369,3 +369,12 @@ func MulAccE4(alpha *E4, scale []fr.Element, res []E4) { mulAccE4_avx512(alpha, &scale[0], &res[0], uint64(N)) } + +// Butterfly computes the butterfly operation on two E4 elements +func Butterfly(a, b *E4) { + fr.Butterfly(&a.B0.A0, &b.B0.A0) + fr.Butterfly(&a.B0.A1, &b.B0.A1) + + fr.Butterfly(&a.B1.A0, &b.B1.A0) + fr.Butterfly(&a.B1.A1, &b.B1.A1) +} diff --git a/field/koalabear/fft/bitreverse.go b/field/koalabear/fft/bitreverse.go index b1e953533..7a57ac0ca 100644 --- a/field/koalabear/fft/bitreverse.go +++ b/field/koalabear/fft/bitreverse.go @@ -9,11 +9,16 @@ import ( "math/bits" "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) +type SmallField interface { + koalabear.Element | fext.E4 +} + // BitReverse applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func BitReverse(v []koalabear.Element) { +func BitReverse[T SmallField](v []T) { n := uint64(len(v)) if bits.OnesCount64(n) != 1 { panic("len(a) must be a power of 2") @@ -24,7 +29,7 @@ func BitReverse(v []koalabear.Element) { // bitReverseNaive applies the bit-reversal permutation to v. // len(v) must be a power of 2 -func bitReverseNaive(v []koalabear.Element) { +func bitReverseNaive[T SmallField](v []T) { n := uint64(len(v)) nn := uint64(64 - bits.TrailingZeros64(n)) diff --git a/field/koalabear/fft/bitreverse_test.go b/field/koalabear/fft/bitreverse_test.go index d1dd3cfcf..31f2ad018 100644 --- a/field/koalabear/fft/bitreverse_test.go +++ b/field/koalabear/fft/bitreverse_test.go @@ -10,22 +10,23 @@ import ( "testing" "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) -type bitReverseVariant struct { +type bitReverseVariant[T SmallField] struct { name string - buf []koalabear.Element - fn func([]koalabear.Element) + buf []T + fn func([]T) } const maxSizeBitReverse = 1 << 23 -var bitReverse = []bitReverseVariant{ - {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive}, - {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse}, +var koalabearBitReverse = []bitReverseVariant[koalabear.Element]{ + {name: "bitReverseNaive", buf: make([]koalabear.Element, maxSizeBitReverse), fn: bitReverseNaive[koalabear.Element]}, + {name: "BitReverse", buf: make([]koalabear.Element, maxSizeBitReverse), fn: BitReverse[koalabear.Element]}, } -func TestBitReverse(t *testing.T) { +func TestElementBitReverse(t *testing.T) { // generate a random []koalabear.Element array of size 2**20 pol := make([]koalabear.Element, maxSizeBitReverse) @@ -39,33 +40,33 @@ func TestBitReverse(t *testing.T) { for size := 2; size <= maxSizeBitReverse; size <<= 1 { // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { copy(data.buf, pol[:size]) } // compute bit reverse shuffling - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { data.fn(data.buf[:size]) } // all bitReverse.buf should hold the same result for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(koalabearBitReverse); j++ { + if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) } } } // bitReverse back should be identity - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { data.fn(data.buf[:size]) } for i := 0; i < size; i++ { - for j := 1; j < len(bitReverse); j++ { - if !bitReverse[0].buf[i].Equal(&bitReverse[j].buf[i]) { - t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", bitReverse[0].name, bitReverse[j].name) + for j := 1; j < len(koalabearBitReverse); j++ { + if !koalabearBitReverse[0].buf[i].Equal(&koalabearBitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", koalabearBitReverse[0].name, koalabearBitReverse[j].name) } } } @@ -73,7 +74,7 @@ func TestBitReverse(t *testing.T) { } -func BenchmarkBitReverse(b *testing.B) { +func BenchmarkElementBitReverse(b *testing.B) { // generate a random []koalabear.Element array of size 2**22 pol := make([]koalabear.Element, maxSizeBitReverse) one := koalabear.One() @@ -83,13 +84,95 @@ func BenchmarkBitReverse(b *testing.B) { } // copy pol into the buffers - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { copy(data.buf, pol[:maxSizeBitReverse]) } // benchmark for each size, each bitReverse function for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { - for _, data := range bitReverse { + for _, data := range koalabearBitReverse { + b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + data.fn(data.buf[:size]) + } + }) + } + } +} + +var e4BitReverse = []bitReverseVariant[fext.E4]{ + {name: "bitReverseNaive", buf: make([]fext.E4, maxSizeBitReverse), fn: bitReverseNaive[fext.E4]}, + {name: "BitReverse", buf: make([]fext.E4, maxSizeBitReverse), fn: BitReverse[fext.E4]}, +} + +func TestE4BitReverse(t *testing.T) { + + // generate a random []koalabear.Element array of size 2**20 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // for each size, check that all the bitReverse functions fn compute the same result. + for size := 2; size <= maxSizeBitReverse; size <<= 1 { + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:size]) + } + + // compute bit reverse shuffling + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + // all bitReverse.buf should hold the same result + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + + // bitReverse back should be identity + for _, data := range e4BitReverse { + data.fn(data.buf[:size]) + } + + for i := 0; i < size; i++ { + for j := 1; j < len(e4BitReverse); j++ { + if !e4BitReverse[0].buf[i].Equal(&e4BitReverse[j].buf[i]) { + t.Fatalf("(fn-1) bitReverse %s and %s do not compute the same result", e4BitReverse[0].name, e4BitReverse[j].name) + } + } + } + } + +} + +func BenchmarkE4BitReverse(b *testing.B) { + // generate a random []E4 array of size 2**22 + pol := make([]fext.E4, maxSizeBitReverse) + var one fext.E4 + one.SetOne() + pol[0].MustSetRandom() + for i := 1; i < maxSizeBitReverse; i++ { + pol[i].Add(&pol[i-1], &one) + } + + // copy pol into the buffers + for _, data := range e4BitReverse { + copy(data.buf, pol[:maxSizeBitReverse]) + } + + // benchmark for each size, each bitReverse function + for size := 1 << 18; size <= maxSizeBitReverse; size <<= 1 { + for _, data := range e4BitReverse { b.Run(fmt.Sprintf("name=%s/size=%d", data.name, size), func(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { diff --git a/field/koalabear/fft/fftext.go b/field/koalabear/fft/fftext.go new file mode 100644 index 000000000..5d57b9f75 --- /dev/null +++ b/field/koalabear/fft/fftext.go @@ -0,0 +1,407 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/big" + "math/bits" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/internal/parallel" + + "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" +) + +// FFTExt computes the discrete Fourier transform of a slice of extension field elements. +// Coefficients and evaluations are extension field elements. +// The root of unity domain is the same as FFT. +func (domain *Domain) FFTExt(a []fext.E4, decimation Decimation, opts ...Option) { + + opt := fftOptions(opts) + + // find the stage where we should stop spawning go routines in our recursive calls + // (ie when we have as many go routines running as we have available CPUs) + maxSplits := bits.TrailingZeros64(ecc.NextPowerOfTwo(uint64(opt.nbTasks))) + if opt.nbTasks == 1 { + maxSplits = -1 + } + + // if coset != 0, scale by coset table + if opt.coset { + + if decimation == DIT { + + // scale by coset table (in bit reversed order) + cosetTable := domain.cosetTable + if !domain.withPrecompute { + // we need to build the full table or do a bit reverse dance. + cosetTable = make([]koalabear.Element, len(a)) + BuildExpTable(domain.FrMultiplicativeGen, cosetTable) + } + parallel.Execute(len(a), func(start, end int) { + n := uint64(len(a)) + nn := uint64(64 - bits.TrailingZeros64(n)) + for i := start; i < end; i++ { + irev := int(bits.Reverse64(uint64(i)) >> nn) + a[i].MulByElement(&a[i], &cosetTable[irev]) + } + }, opt.nbTasks) + } else { + + if domain.withPrecompute { + parallel.Execute(len(a), func(start, end int) { + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &domain.cosetTable[i]) + } + }, opt.nbTasks) + } else { + c := domain.FrMultiplicativeGen + parallel.Execute(len(a), func(start, end int) { + var at koalabear.Element + at.Exp(c, big.NewInt(int64(start))) + for i := start; i < end; i++ { + a[i].MulByElement(&a[i], &at) + at.Mul(&at, &c) + } + }, opt.nbTasks) + } + + } + } + + twiddles := domain.twiddles + twiddlesStartStage := 0 + if !domain.withPrecompute { + twiddlesStartStage = 3 + nbStages := int(bits.TrailingZeros64(domain.Cardinality)) + if nbStages-twiddlesStartStage > 0 { + twiddles = make([][]koalabear.Element, nbStages-twiddlesStartStage) + w := domain.Generator + w.Exp(w, big.NewInt(int64(1< 0 { + twiddlesInv = make([][]koalabear.Element, nbStages-twiddlesStartStage) + w := domain.GeneratorInv + w.Exp(w, big.NewInt(int64(1<> nn) + a[i].MulByElement(&a[i], &cosetTableInv[irev]). + MulByElement(&a[i], &domain.CardinalityInv) + } + }, opt.nbTasks) + +} + +func difFFTExt(a []fext.E4, w koalabear.Element, twiddles [][]koalabear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDIFNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + m := n >> 1 + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at koalabear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDIFWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + } else { + innerDIFWithoutTwiddlesExt(a, w, w, 0, m, m) + } + // compute next twiddle + w.Square(&w) + } else { + innerDIFWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) + } + + if m == 1 { + return + } + + nextStage := stage + 1 + if stage < maxSplits { + chDone := make(chan struct{}, 1) + go difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + difFFTExt(a[0:m], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + difFFTExt(a[m:n], w, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + +} + +func innerDIFWithTwiddlesGenericExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + } +} + +func innerDIFWithoutTwiddlesExt(a []fext.E4, at, w koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + fext.Butterfly(&a[i], &a[i+m]) + a[i+m].MulByElement(&a[i+m], &at) + at.Mul(&at, &w) + } +} + +func ditFFTExt(a []fext.E4, w koalabear.Element, twiddles [][]koalabear.Element, twiddlesStartStage, stage, maxSplits int, chDone chan struct{}, nbTasks int) { + if chDone != nil { + defer close(chDone) + } + n := len(a) + if n == 1 { + return + } else if stage >= twiddlesStartStage { + if n == 1<<8 { + kerDITNP_256Ext(a, twiddles, stage-twiddlesStartStage) + return + } + } + + m := n >> 1 + + nextStage := stage + 1 + nextW := w + nextW.Square(&nextW) + + if stage < maxSplits { + // that's the only time we fire go routines + chDone := make(chan struct{}, 1) + + go ditFFTExt(a[m:], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, chDone, nbTasks) + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + <-chDone + } else { + + ditFFTExt(a[0:m], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + ditFFTExt(a[m:n], nextW, twiddles, twiddlesStartStage, nextStage, maxSplits, nil, nbTasks) + } + + parallelButterfly := (m > butterflyThreshold) && (stage < maxSplits) + + if stage < twiddlesStartStage { + // we need to compute the twiddles for this stage on the fly. + if parallelButterfly { + w := w + parallel.Execute(m, func(start, end int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + var at koalabear.Element + at.Exp(w, big.NewInt(int64(start))) + innerDITWithoutTwiddlesExt(a, at, w, start, end, m) + }, nbTasks/(1<<(stage))) // 1 << stage == estimated used CPUs + + } else { + innerDITWithoutTwiddlesExt(a, w, w, 0, m, m) + } + return + } + innerDITWithTwiddlesExt(a, twiddles[stage-twiddlesStartStage], 0, m, m) +} + +func innerDITWithTwiddlesGenericExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &twiddles[i]) + fext.Butterfly(&a[i], &a[i+m]) + } +} + +func innerDITWithoutTwiddlesExt(a []fext.E4, at, w koalabear.Element, start, end, m int) { + if start == 0 { + fext.Butterfly(&a[0], &a[m]) + start++ + } + for i := start; i < end; i++ { + a[i+m].MulByElement(&a[i+m], &at) + fext.Butterfly(&a[i], &a[i+m]) + at.Mul(&at, &w) + } +} + +func kerDIFNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + innerDIFWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) + for offset := 0; offset < 256; offset += 128 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + for offset := 0; offset < 256; offset += 64 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 32 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 16 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 8 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 4 { + innerDIFWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } +} + +func kerDITNP_256genericExt(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + // code unrolled & generated by internal/generator/fft/template/fftext.go.tmpl + + for offset := 0; offset < 256; offset += 2 { + fext.Butterfly(&a[offset], &a[offset+1]) + } + for offset := 0; offset < 256; offset += 4 { + innerDITWithTwiddlesGenericExt(a[offset:offset+4], twiddles[stage+6], 0, 2, 2) + } + for offset := 0; offset < 256; offset += 8 { + innerDITWithTwiddlesGenericExt(a[offset:offset+8], twiddles[stage+5], 0, 4, 4) + } + for offset := 0; offset < 256; offset += 16 { + innerDITWithTwiddlesGenericExt(a[offset:offset+16], twiddles[stage+4], 0, 8, 8) + } + for offset := 0; offset < 256; offset += 32 { + innerDITWithTwiddlesGenericExt(a[offset:offset+32], twiddles[stage+3], 0, 16, 16) + } + for offset := 0; offset < 256; offset += 64 { + innerDITWithTwiddlesGenericExt(a[offset:offset+64], twiddles[stage+2], 0, 32, 32) + } + for offset := 0; offset < 256; offset += 128 { + innerDITWithTwiddlesGenericExt(a[offset:offset+128], twiddles[stage+1], 0, 64, 64) + } + innerDITWithTwiddlesGenericExt(a[:256], twiddles[stage+0], 0, 128, 128) +} diff --git a/field/koalabear/fft/fftext_test.go b/field/koalabear/fft/fftext_test.go new file mode 100644 index 000000000..3bada0049 --- /dev/null +++ b/field/koalabear/fft/fftext_test.go @@ -0,0 +1,400 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fft + +import ( + "math/big" + "strconv" + "testing" + + "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "encoding/binary" + "fmt" + "github.com/stretchr/testify/require" + "math/rand/v2" +) + +func TestFFTExt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 6 + properties := gopter.NewProperties(parameters) + + for maxSize := 2; maxSize <= 1<<10; maxSize <<= 1 { + + domainWithPrecompute := NewDomain(uint64(maxSize)) + domainWithoutPrecompute := NewDomain(uint64(maxSize), WithoutPrecompute()) + + for domainName, domain := range map[string]*Domain{ + "with precompute": domainWithPrecompute, + "without precompute": domainWithoutPrecompute, + } { + domainName := domainName + domain := domain + t.Logf("domain: %s", domainName) + properties.Property("DIF FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIF FFT on cosets should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTExt(pol, DIF, OnCoset()) + BitReverse(pol) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))). + Mul(&sample, &domain.FrMultiplicativeGen) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("DIT FFT should be consistent with dual basis", prop.ForAll( + + // checks that a random evaluation of a dual function eval(gen**ithpower) is consistent with the FFT result + func(ithpower int) bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + + sample := domain.Generator + sample.Exp(sample, big.NewInt(int64(ithpower))) + + eval := evaluatePolynomialExt(backupPol, sample) + + return eval.Equal(&pol[ithpower]) + + }, + gen.IntRange(0, maxSize-1), + )) + + properties.Property("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + BitReverse(pol) + domain.FFTExt(pol, DIT) + domain.FFTInverseExt(pol, DIF) + BitReverse(pol) + + check := true + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + return check + }, + )) + + for nbCosets := 2; nbCosets < 5; nbCosets++ { + properties.Property(fmt.Sprintf("bitReverse(DIF FFT(DIT FFT (bitReverse))))==id on %d cosets", nbCosets), prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + check := true + + for i := 1; i <= nbCosets; i++ { + + BitReverse(pol) + domain.FFTExt(pol, DIT, OnCoset()) + domain.FFTInverseExt(pol, DIF, OnCoset()) + BitReverse(pol) + + for i := 0; i < len(pol); i++ { + check = check && pol[i].Equal(&backupPol[i]) + } + } + + return check + }, + )) + } + + properties.Property("DIT FFT(DIF FFT)==id", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF) + domain.FFTExt(pol, DIT) + + check := true + for i := 0; i < len(pol); i++ { + check = check && (pol[i] == backupPol[i]) + } + return check + }, + )) + + properties.Property("DIT FFT(DIF FFT)==id on cosets", prop.ForAll( + + func() bool { + + pol := make([]fext.E4, maxSize) + backupPol := make([]fext.E4, maxSize) + + for i := 0; i < maxSize; i++ { + pol[i].MustSetRandom() + } + copy(backupPol, pol) + + domain.FFTInverseExt(pol, DIF, OnCoset()) + domain.FFTExt(pol, DIT, OnCoset()) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + // compute with nbTasks == 1 + domain.FFTInverseExt(pol, DIF, OnCoset(), WithNbTasks(1)) + domain.FFTExt(pol, DIT, OnCoset(), WithNbTasks(1)) + + for i := 0; i < len(pol); i++ { + if !(pol[i].Equal(&backupPol[i])) { + return false + } + } + + return true + }, + )) + } + properties.TestingRun(t, gopter.ConsoleReporter(false)) + } + +} + +func randElementExt(rng *rand.Rand) fext.E4 { + var v fext.E4 + v.B0.A0 = koalabear.Element{rng.Uint32N(2130706433)} + v.B0.A1 = koalabear.Element{rng.Uint32N(2130706433)} + v.B1.A0 = koalabear.Element{rng.Uint32N(2130706433)} + v.B1.A1 = koalabear.Element{rng.Uint32N(2130706433)} + return v +} + +func FuzzFFTExt(f *testing.F) { + f.Fuzz(func(t *testing.T, domainSize uint16, rngSeed int64) { + if domainSize > (1 << 13) { + t.Skip("domain size too large") + } + if domainSize < 2 { + t.Skip("domain size too small") + } + + domain := NewDomain(uint64(domainSize)) + + var seed [32]byte + binary.PutVarint(seed[:], rngSeed) + // #nosec G404 -- fuzz does not require a cryptographic PRNG + rng := rand.New(rand.NewChaCha8(seed)) + + cardinality := domain.Cardinality + + // we just check that FFT-1(FFT(pol)) == pol + a, b := make([]fext.E4, cardinality), make([]fext.E4, cardinality) + for i := 0; i < int(cardinality); i++ { + a[i] = randElementExt(rng) + } + copy(b, a) + + domain.FFTInverseExt(a, DIF) + domain.FFTExt(a, DIT) + + assert := require.New(t) + for i := 0; i < int(cardinality); i++ { + assert.True(a[i].Equal(&b[i]), "FFT-1(FFT(pol)) != pol at index %d", i) + } + }) +} + +// -------------------------------------------------------------------- +// benches + +func BenchmarkFFTExt(b *testing.B) { + + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + for i := 8; i < 20; i++ { + sizeDomain := 1 << i + b.Run("fft 2**"+strconv.Itoa(i)+"bits", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT) + } + }) + b.Run("fft 2**"+strconv.Itoa(i)+"bits (coset)", func(b *testing.B) { + domain := NewDomain(uint64(sizeDomain)) + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol[:sizeDomain], DIT, OnCoset()) + } + }) + } + +} + +func BenchmarkFFTDITCosetReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIT, OnCoset()) + } +} + +func BenchmarkFFTDITReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < 1; j++ { + domain.FFTExt(pol, DIT) + } +} + +func BenchmarkFFTDIFReferenceExt(b *testing.B) { + const maxSize = 1 << 20 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} +func BenchmarkFFTDIFReferenceSmallExt(b *testing.B) { + const maxSize = 1 << 9 + + pol := make([]fext.E4, maxSize) + pol[0].MustSetRandom() + for i := 1; i < maxSize; i++ { + pol[i] = pol[i-1] + } + + domain := NewDomain(maxSize) + + b.ResetTimer() + for j := 0; j < b.N; j++ { + domain.FFTExt(pol, DIF) + } +} + +func evaluatePolynomialExt(pol []fext.E4, val koalabear.Element) fext.E4 { + var res, tmp fext.E4 + var acc koalabear.Element + res.Set(&pol[0]) + acc.Set(&val) + for i := 1; i < len(pol); i++ { + tmp.MulByElement(&pol[i], &acc) + res.Add(&res, &tmp) + acc.Mul(&acc, &val) + } + return res +} diff --git a/field/koalabear/fft/kernel_amd64.go b/field/koalabear/fft/kernel_amd64.go index 7abce70f4..fb2410724 100644 --- a/field/koalabear/fft/kernel_amd64.go +++ b/field/koalabear/fft/kernel_amd64.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" "github.com/consensys/gnark-crypto/utils/cpu" ) @@ -65,3 +66,35 @@ func kerDITNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage i } kerDITNP_256_avx512(a, twiddles, stage) } + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + //todo: use AVX512 +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + if !cpu.SupportAVX512 || m < 16 { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) + return + } + //todo: use AVX512 +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDIFNP_256genericExt(a, twiddles, stage) + return + } + //todo: use AVX512 +} + +func kerDITNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + if !cpu.SupportAVX512 { + kerDITNP_256genericExt(a, twiddles, stage) + return + } + //todo: use AVX512 +} diff --git a/field/koalabear/fft/kernel_purego.go b/field/koalabear/fft/kernel_purego.go index c711d7d72..ac907352e 100644 --- a/field/koalabear/fft/kernel_purego.go +++ b/field/koalabear/fft/kernel_purego.go @@ -9,6 +9,7 @@ package fft import ( "github.com/consensys/gnark-crypto/field/koalabear" + fext "github.com/consensys/gnark-crypto/field/koalabear/extensions" ) func innerDIFWithTwiddles(a []koalabear.Element, twiddles []koalabear.Element, start, end, m int) { @@ -25,3 +26,18 @@ func kerDIFNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage i func kerDITNP_256(a []koalabear.Element, twiddles [][]koalabear.Element, stage int) { kerDITNP_256generic(a, twiddles, stage) } + +func innerDIFWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + innerDIFWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func innerDITWithTwiddlesExt(a []fext.E4, twiddles []koalabear.Element, start, end, m int) { + innerDITWithTwiddlesGenericExt(a, twiddles, start, end, m) +} + +func kerDIFNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + kerDIFNP_256genericExt(a, twiddles, stage) +} +func kerDITNP_256Ext(a []fext.E4, twiddles [][]koalabear.Element, stage int) { + kerDITNP_256genericExt(a, twiddles, stage) +}