Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AVX512 GFNI processing #224

Merged
merged 10 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ jobs:
with:
go-version: ${{ matrix.go-version }}

- name: CPU support
run: go install github.com/klauspost/cpuid/v2/cmd/cpuid@latest&&cpuid

- name: Checkout code
uses: actions/checkout@v2

Expand Down Expand Up @@ -71,6 +74,11 @@ jobs:
CGO_ENABLED: 1
run: go test -tags=noasm -cpu=4 -short -race -timeout 20m .

- name: Test Races, no gfni
env:
CGO_ENABLED: 1
run: go test -no-gfni -short -race

- name: Test Races, no avx512
env:
CGO_ENABLED: 1
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Using Go modules is recommended.

## 2022

* [GFNI](https://github.com/klauspost/reedsolomon/pull/224) support for amd64, for up to 3x faster processing.
* [Leopard GF8](https://github.com/klauspost/reedsolomon#leopard-gf8) mode added, for faster processing of medium shard counts.
* [Leopard GF16](https://github.com/klauspost/reedsolomon#leopard-compatible-gf16) mode added, for up to 65536 shards.
* [WithJerasureMatrix](https://pkg.go.dev/github.com/klauspost/reedsolomon?tab=doc#WithJerasureMatrix) allows constructing a [Jerasure](https://github.com/tsuraan/Jerasure) compatible matrix.
Expand Down Expand Up @@ -480,7 +481,8 @@ BenchmarkReconstruct50x20x1M-8 1364.35 4189.79 3.07x
BenchmarkReconstruct10x4x16M-8 1484.35 5779.53 3.89x
```

The performance on AVX512 has been accelerated for CPUs when available.
The package will use [GFNI](https://en.wikipedia.org/wiki/AVX-512#GFNI) instructions combined with AVX512 when these are available.
This further improves speed by up to 3x over AVX2 code paths.

## ARM64 NEON

Expand Down
1 change: 1 addition & 0 deletions _gen/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func main() {
}
data = bytes.ReplaceAll(data, []byte("\t// #"), []byte("#"))
data = bytes.ReplaceAll(data, []byte("\t// @"), []byte(""))
data = bytes.ReplaceAll(data, []byte("VALIGNQ"), []byte("VGF2P8AFFINEQB"))
data = bytes.ReplaceAll(data, []byte("VPTERNLOGQ"), []byte("XOR3WAY("))
split := bytes.Split(data, []byte("\n"))
// Add closing ')'
Expand Down
278 changes: 278 additions & 0 deletions _gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ const outputMax = 10
var switchDefs [inputMax][outputMax]string
var switchDefsX [inputMax][outputMax]string

var switchDefs512 [inputMax][outputMax]string
var switchDefsX512 [inputMax][outputMax]string

// Prefetch offsets, set to 0 to disable.
// Disabled since they appear to be consistently slower.
const prefetchSrc = 0
Expand Down Expand Up @@ -58,6 +61,8 @@ func main() {
for j := 1; j <= outputMax; j++ {
genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%d", i, j), i, j, false)
genMulAvx2Sixty64(fmt.Sprintf("mulAvxTwo_%dx%d_64", i, j), i, j, false)
genMulAvx512GFNI(fmt.Sprintf("mulGFNI_%dx%d_64", i, j), i, j, false)
genMulAvx512GFNI(fmt.Sprintf("mulGFNI_%dx%d_64Xor", i, j), i, j, true)
genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%dXor", i, j), i, j, true)
genMulAvx2Sixty64(fmt.Sprintf("mulAvxTwo_%dx%d_64Xor", i, j), i, j, true)
}
Expand Down Expand Up @@ -131,6 +136,48 @@ func galMulSlicesAvx2Xor(matrix []byte, in, out [][]byte, start, stop int) int {
panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
}
`)

w.WriteString(`

func galMulSlicesGFNI(matrix []uint64, in, out [][]byte, start, stop int) int {
n := (stop-start) & avxSizeMask

`)

w.WriteString(`switch len(in) {
`)
for in, defs := range switchDefs512[:] {
w.WriteString(fmt.Sprintf(" case %d:\n switch len(out) {\n", in+1))
for out, def := range defs[:] {
w.WriteString(fmt.Sprintf(" case %d:\n", out+1))
w.WriteString(def)
}
w.WriteString("}\n")
}
w.WriteString(`}
panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
}

func galMulSlicesGFNIXor(matrix []uint64, in, out [][]byte, start, stop int) int {
n := (stop-start) & avxSizeMask

`)

w.WriteString(`switch len(in) {
`)
for in, defs := range switchDefsX512[:] {
w.WriteString(fmt.Sprintf(" case %d:\n switch len(out) {\n", in+1))
for out, def := range defs[:] {
w.WriteString(fmt.Sprintf(" case %d:\n", out+1))
w.WriteString(def)
}
w.WriteString("}\n")
}
w.WriteString(`}
panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
}
`)

genGF16()
genGF8()
Generate()
Expand Down Expand Up @@ -657,3 +704,234 @@ func genMulAvx2Sixty64(name string, inputs int, outputs int, xor bool) {
Label(name + "_end")
RET()
}

func genMulAvx512GFNI(name string, inputs int, outputs int, xor bool) {
const perLoopBits = 6
const perLoop = 1 << perLoopBits

total := inputs * outputs

doc := []string{
fmt.Sprintf("%s takes %d inputs and produces %d outputs.", name, inputs, outputs),
}
if !xor {
doc = append(doc, "The output is initialized to 0.")
}

// Load shuffle masks on every use.
var loadNone bool
// Use registers for destination registers.
var regDst = true
var reloadLength = false

est := total + outputs + 2
// When we can't hold all, keep this many in registers.
inReg := 0
if est > 32 {
loadNone = true
inReg = 32 - outputs - 2
// We run out of GP registers first, now.
if inputs+outputs > 13 {
regDst = false
}
// Save one register by reloading length.
if inputs+outputs > 12 && regDst {
reloadLength = true
}
}

TEXT(name, 0, fmt.Sprintf("func(matrix []uint64, in [][]byte, out [][]byte, start, n int)"))
x := ""
if xor {
x = "Xor"
}
// SWITCH DEFINITION:
//s := fmt.Sprintf("n = (n>>%d)<<%d\n", perLoopBits, perLoopBits)
s := fmt.Sprintf(" mulGFNI_%dx%d_64%s(matrix, in, out, start, n)\n", inputs, outputs, x)
s += fmt.Sprintf("\t\t\t\treturn n\n")
if xor {
switchDefsX512[inputs-1][outputs-1] = s
} else {
switchDefs512[inputs-1][outputs-1] = s
}

if loadNone {
Commentf("Loading %d of %d tables to registers", inReg, inputs*outputs)
} else {
// loadNone == false
Comment("Loading all tables to registers")
}
if regDst {
Comment("Destination kept in GP registers")
} else {
Comment("Destination kept on stack")
}

Doc(doc...)
Pragma("noescape")
Commentf("Full registers estimated %d YMM used", est)

length := Load(Param("n"), GP64())
matrixBase := GP64()
addr, err := Param("matrix").Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(addr.Addr, matrixBase)
SHRQ(U8(perLoopBits), length)
TESTQ(length, length)
JZ(LabelRef(name + "_end"))

matrix := make([]reg.VecVirtual, total)

for i := range matrix {
if loadNone && i >= inReg {
break
}
table := ZMM()
VBROADCASTF32X2(Mem{Base: matrixBase, Disp: i * 8}, table)
matrix[i] = table
}

inPtrs := make([]reg.GPVirtual, inputs)
inSlicePtr := GP64()
addr, err = Param("in").Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(addr.Addr, inSlicePtr)
for i := range inPtrs {
ptr := GP64()
MOVQ(Mem{Base: inSlicePtr, Disp: i * 24}, ptr)
inPtrs[i] = ptr
}
// Destination
dst := make([]reg.VecVirtual, outputs)
dstPtr := make([]reg.GPVirtual, outputs)
addr, err = Param("out").Base().Resolve()
if err != nil {
panic(err)
}
outBase := addr.Addr
outSlicePtr := GP64()
MOVQ(addr.Addr, outSlicePtr)
MOVQ(outBase, outSlicePtr)
for i := range dst {
dst[i] = ZMM()
if !regDst {
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
dstPtr[i] = ptr
}

offset := GP64()
addr, err = Param("start").Resolve()
if err != nil {
panic(err)
}

MOVQ(addr.Addr, offset)
if regDst {
Comment("Add start offset to output")
for _, ptr := range dstPtr {
ADDQ(offset, ptr)
}
}

Comment("Add start offset to input")
for _, ptr := range inPtrs {
ADDQ(offset, ptr)
}
// Offset no longer needed unless not regdst

if reloadLength {
Commentf("Reload length to save a register")
length = Load(Param("n"), GP64())
SHRQ(U8(perLoopBits), length)
}
Label(name + "_loop")

if xor {
Commentf("Load %d outputs", outputs)
for i := range dst {
if regDst {
VMOVDQU64(Mem{Base: dstPtr[i]}, dst[i])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: dstPtr[i], Disp: prefetchDst})
}
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU64(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i])

if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
}

in := ZMM()
look := ZMM()
for i := range inPtrs {
Commentf("Load and process 64 bytes from input %d to %d outputs", i, outputs)
VMOVDQU64(Mem{Base: inPtrs[i]}, in)
if prefetchSrc > 0 {
PREFETCHT0(Mem{Base: inPtrs[i], Disp: prefetchSrc})
}
ADDQ(U8(perLoop), inPtrs[i])

for j := range dst {
idx := i*outputs + j
if loadNone && idx >= inReg {
VBROADCASTF32X2(Mem{Base: matrixBase, Disp: 8 * idx}, look)
if i == 0 && !xor {
// Converted to VGF2P8AFFINEQB
VALIGNQ(U8(0), look, in, dst[j])
} else {
// Converted to VGF2P8AFFINEQB
VALIGNQ(U8(0), look, in, look)
VXORPD(dst[j], look, dst[j])
}
} else {
if i == 0 && !xor {
// Converted to VGF2P8AFFINEQB
VALIGNQ(U8(0), matrix[i*outputs+j], in, dst[j])
} else {
// Converted to VGF2P8AFFINEQB
VALIGNQ(U8(0), matrix[i*outputs+j], in, look)
VXORPD(dst[j], look, dst[j])
}
}
}
}
Commentf("Store %d outputs", outputs)
for i := range dst {
if regDst {
VMOVDQU64(dst[i], Mem{Base: dstPtr[i]})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: dstPtr[i], Disp: prefetchDst})
}
ADDQ(U8(perLoop), dstPtr[i])
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU64(dst[i], Mem{Base: ptr, Index: offset, Scale: 1})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
Comment("Prepare for next loop")
if !regDst {
ADDQ(U8(perLoop), offset)
}
DECQ(length)
JNZ(LabelRef(name + "_loop"))
VZEROUPPER()

Label(name + "_end")
RET()
}
18 changes: 18 additions & 0 deletions galois.go
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,24 @@ func genAvx2Matrix(matrixRows [][]byte, inputs, inIdx, outputs int, dst []byte)
return dst
}

var gf2p811dMulMatrices = [256]uint64{0, 0x102040810204080, 0x8001828488102040, 0x8103868c983060c0, 0x408041c2c4881020, 0x418245cad4a850a0, 0xc081c3464c983060, 0xc183c74e5cb870e0, 0x2040a061e2c48810, 0x2142a469f2e4c890, 0xa04122e56ad4a850, 0xa14326ed7af4e8d0, 0x60c0e1a3264c9830, 0x61c2e5ab366cd8b0, 0xe0c16327ae5cb870, 0xe1c3672fbe7cf8f0, 0x102050b071e2c488, 0x112254b861c28408, 0x9021d234f9f2e4c8, 0x9123d63ce9d2a448, 0x50a01172b56ad4a8, 0x51a2157aa54a9428, 0xd0a193f63d7af4e8, 0xd1a397fe2d5ab468, 0x3060f0d193264c98, 0x3162f4d983060c18, 0xb06172551b366cd8, 0xb163765d0b162c58, 0x70e0b11357ae5cb8, 0x71e2b51b478e1c38, 0xf0e13397dfbe7cf8, 0xf1e3379fcf9e3c78, 0x8810a8d83871e2c4, 0x8912acd02851a244, 0x8112a5cb061c284, 0x9132e54a0418204, 0xc890e91afcf9f2e4, 0xc992ed12ecd9b264, 0x48916b9e74e9d2a4, 0x49936f9664c99224, 0xa85008b9dab56ad4, 0xa9520cb1ca952a54, 0x28518a3d52a54a94, 0x29538e3542850a14, 0xe8d0497b1e3d7af4, 0xe9d24d730e1d3a74, 0x68d1cbff962d5ab4, 0x69d3cff7860d1a34, 0x9830f8684993264c, 0x9932fc6059b366cc, 0x18317aecc183060c, 0x19337ee4d1a3468c, 0xd8b0b9aa8d1b366c, 0xd9b2bda29d3b76ec, 0x58b13b2e050b162c, 0x59b33f26152b56ac, 0xb8705809ab57ae5c, 0xb9725c01bb77eedc, 0x3871da8d23478e1c, 0x3973de853367ce9c, 0xf8f019cb6fdfbe7c, 0xf9f21dc37ffffefc, 0x78f19b4fe7cf9e3c, 0x79f39f47f7efdebc, 0xc488d46c1c3871e2, 0xc58ad0640c183162, 0x448956e8942851a2, 0x458b52e084081122, 0x840895aed8b061c2, 0x850a91a6c8902142, 0x409172a50a04182, 0x50b132240800102, 0xe4c8740dfefcf9f2, 0xe5ca7005eedcb972, 0x64c9f68976ecd9b2, 0x65cbf28166cc9932, 0xa44835cf3a74e9d2, 0xa54a31c72a54a952, 0x2449b74bb264c992, 0x254bb343a2448912, 0xd4a884dc6ddab56a, 0xd5aa80d47dfaf5ea, 0x54a90658e5ca952a, 0x55ab0250f5ead5aa, 0x9428c51ea952a54a, 0x952ac116b972e5ca, 0x1429479a2142850a, 0x152b43923162c58a, 0xf4e824bd8f1e3d7a, 0xf5ea20b59f3e7dfa, 0x74e9a639070e1d3a, 0x75eba231172e5dba, 0xb468657f4b962d5a, 0xb56a61775bb66dda, 0x3469e7fbc3860d1a, 0x356be3f3d3a64d9a, 0x4c987cb424499326, 0x4d9a78bc3469d3a6, 0xcc99fe30ac59b366, 0xcd9bfa38bc79f3e6, 0xc183d76e0c18306, 0xd1a397ef0e1c386, 0x8c19bff268d1a346, 0x8d1bbbfa78f1e3c6, 0x6cd8dcd5c68d1b36, 0x6ddad8ddd6ad5bb6, 0xecd95e514e9d3b76, 0xeddb5a595ebd7bf6, 0x2c589d1702050b16, 0x2d5a991f12254b96, 0xac591f938a152b56, 0xad5b1b9b9a356bd6, 0x5cb82c0455ab57ae, 0x5dba280c458b172e, 0xdcb9ae80ddbb77ee, 0xddbbaa88cd9b376e, 0x1c386dc69123478e, 0x1d3a69ce8103070e, 0x9c39ef42193367ce, 0x9d3beb4a0913274e, 0x7cf88c65b76fdfbe, 0x7dfa886da74f9f3e, 0xfcf90ee13f7ffffe, 0xfdfb0ae92f5fbf7e, 0x3c78cda773e7cf9e, 0x3d7ac9af63c78f1e, 0xbc794f23fbf7efde, 0xbd7b4b2bebd7af5e, 0xe2c46a368e1c3871, 0xe3c66e3e9e3c78f1, 0x62c5e8b2060c1831, 0x63c7ecba162c58b1, 0xa2442bf44a942851, 0xa3462ffc5ab468d1, 0x2245a970c2840811, 0x2347ad78d2a44891, 0xc284ca576cd8b061, 0xc386ce5f7cf8f0e1, 0x428548d3e4c89021, 0x43874cdbf4e8d0a1, 0x82048b95a850a041, 0x83068f9db870e0c1, 0x205091120408001, 0x3070d193060c081, 0xf2e43a86fffefcf9, 0xf3e63e8eefdebc79, 0x72e5b80277eedcb9, 0x73e7bc0a67ce9c39, 0xb2647b443b76ecd9, 0xb3667f4c2b56ac59, 0x3265f9c0b366cc99, 0x3367fdc8a3468c19, 0xd2a49ae71d3a74e9, 0xd3a69eef0d1a3469, 0x52a51863952a54a9, 0x53a71c6b850a1429, 0x9224db25d9b264c9, 0x9326df2dc9922449, 0x122559a151a24489, 0x13275da941820409, 0x6ad4c2eeb66ddab5, 0x6bd6c6e6a64d9a35, 0xead5406a3e7dfaf5, 0xebd744622e5dba75, 0x2a54832c72e5ca95, 0x2b56872462c58a15, 0xaa5501a8faf5ead5, 0xab5705a0ead5aa55, 0x4a94628f54a952a5, 0x4b96668744891225, 0xca95e00bdcb972e5, 0xcb97e403cc993265, 0xa14234d90214285, 0xb16274580010205, 0x8a15a1c9183162c5, 0x8b17a5c108112245, 0x7af4925ec78f1e3d, 0x7bf69656d7af5ebd, 0xfaf510da4f9f3e7d, 0xfbf714d25fbf7efd, 0x3a74d39c03070e1d, 0x3b76d79413274e9d, 0xba7551188b172e5d, 0xbb7755109b376edd, 0x5ab4323f254b962d, 0x5bb63637356bd6ad, 0xdab5b0bbad5bb66d, 0xdbb7b4b3bd7bf6ed, 0x1a3473fde1c3860d, 0x1b3677f5f1e3c68d, 0x9a35f17969d3a64d, 0x9b37f57179f3e6cd, 0x264cbe5a92244993, 0x274eba5282040913, 0xa64d3cde1a3469d3, 0xa74f38d60a142953, 0x66ccff9856ac59b3, 0x67cefb90468c1933, 0xe6cd7d1cdebc79f3, 0xe7cf7914ce9c3973, 0x60c1e3b70e0c183, 0x70e1a3360c08103, 0x860d9cbff8f0e1c3, 0x870f98b7e8d0a143, 0x468c5ff9b468d1a3, 0x478e5bf1a4489123, 0xc68ddd7d3c78f1e3, 0xc78fd9752c58b163, 0x366ceeeae3c68d1b, 0x376eeae2f3e6cd9b, 0xb66d6c6e6bd6ad5b, 0xb76f68667bf6eddb, 0x76ecaf28274e9d3b, 0x77eeab20376eddbb, 0xf6ed2dacaf5ebd7b, 0xf7ef29a4bf7efdfb, 0x162c4e8b0102050b, 0x172e4a831122458b, 0x962dcc0f8912254b, 0x972fc807993265cb, 0x56ac0f49c58a152b, 0x57ae0b41d5aa55ab, 0xd6ad8dcd4d9a356b, 0xd7af89c55dba75eb, 0xae5c1682aa55ab57, 0xaf5e128aba75ebd7, 0x2e5d940622458b17, 0x2f5f900e3265cb97, 0xeedc57406eddbb77, 0xefde53487efdfbf7, 0x6eddd5c4e6cd9b37, 0x6fdfd1ccf6eddbb7, 0x8e1cb6e348912347, 0x8f1eb2eb58b163c7, 0xe1d3467c0810307, 0xf1f306fd0a14387, 0xce9cf7218c193367, 0xcf9ef3299c3973e7, 0x4e9d75a504091327, 0x4f9f71ad142953a7, 0xbe7c4632dbb76fdf, 0xbf7e423acb972f5f, 0x3e7dc4b653a74f9f, 0x3f7fc0be43870f1f, 0xfefc07f01f3f7fff, 0xfffe03f80f1f3f7f, 0x7efd8574972f5fbf, 0x7fff817c870f1f3f, 0x9e3ce6533973e7cf, 0x9f3ee25b2953a74f, 0x1e3d64d7b163c78f, 0x1f3f60dfa143870f, 0xdebca791fdfbf7ef, 0xdfbea399eddbb76f, 0x5ebd251575ebd7af, 0x5fbf211d65cb972f}

func genGFNIMatrix(matrixRows [][]byte, inputs, inIdx, outputs int, dst []uint64) []uint64 {
if !avx2CodeGen {
panic("codegen not enabled")
}
total := inputs * outputs

// Duplicated in+out
dst = dst[:total]
for i, row := range matrixRows[:outputs] {
for j, idx := range row[inIdx : inIdx+inputs] {
dst[j*outputs+i] = gf2p811dMulMatrices[idx]
}
}
return dst
}

// xor slices writing to out.
func sliceXorGo(in, out []byte, _ *options) {
for len(out) >= 32 {
Expand Down
Loading