Skip to content

Commit

Permalink
Add AVX512 GFNI processing
Browse files Browse the repository at this point in the history
WIP
  • Loading branch information
klauspost committed Nov 15, 2022
1 parent 4d2013d commit bf0ce8a
Show file tree
Hide file tree
Showing 9 changed files with 36,999 additions and 21 deletions.
1 change: 1 addition & 0 deletions _gen/cleanup.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func main() {
}
data = bytes.ReplaceAll(data, []byte("\t// #"), []byte("#"))
data = bytes.ReplaceAll(data, []byte("\t// @"), []byte(""))
data = bytes.ReplaceAll(data, []byte("VALIGNQ"), []byte("VGF2P8AFFINEQB"))
data = bytes.ReplaceAll(data, []byte("VPTERNLOGQ"), []byte("XOR3WAY("))
split := bytes.Split(data, []byte("\n"))
// Add closing ')'
Expand Down
271 changes: 271 additions & 0 deletions _gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ const outputMax = 10
var switchDefs [inputMax][outputMax]string
var switchDefsX [inputMax][outputMax]string

var switchDefs512 [inputMax][outputMax]string
var switchDefsX512 [inputMax][outputMax]string

// Prefetch offsets, set to 0 to disable.
// Disabled since they appear to be consistently slower.
const prefetchSrc = 0
Expand Down Expand Up @@ -58,6 +61,8 @@ func main() {
for j := 1; j <= outputMax; j++ {
genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%d", i, j), i, j, false)
genMulAvx2Sixty64(fmt.Sprintf("mulAvxTwo_%dx%d_64", i, j), i, j, false)
genMulAvx512GFNI(fmt.Sprintf("mulGFNI_%dx%d_64", i, j), i, j, false)
genMulAvx512GFNI(fmt.Sprintf("mulGFNI_%dx%d_64Xor", i, j), i, j, true)
genMulAvx2(fmt.Sprintf("mulAvxTwo_%dx%dXor", i, j), i, j, true)
genMulAvx2Sixty64(fmt.Sprintf("mulAvxTwo_%dx%d_64Xor", i, j), i, j, true)
}
Expand Down Expand Up @@ -131,6 +136,48 @@ func galMulSlicesAvx2Xor(matrix []byte, in, out [][]byte, start, stop int) int {
panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
}
`)

w.WriteString(`
func galMulSlicesGFNI(matrix []uint64, in, out [][]byte, start, stop int) int {
n := (stop-start) & avxSizeMask
`)

w.WriteString(`switch len(in) {
`)
for in, defs := range switchDefs512[:] {
w.WriteString(fmt.Sprintf(" case %d:\n switch len(out) {\n", in+1))
for out, def := range defs[:] {
w.WriteString(fmt.Sprintf(" case %d:\n", out+1))
w.WriteString(def)
}
w.WriteString("}\n")
}
w.WriteString(`}
panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
}
func galMulSlicesGFNIXor(matrix []uint64, in, out [][]byte, start, stop int) int {
n := (stop-start) & avxSizeMask
`)

w.WriteString(`switch len(in) {
`)
for in, defs := range switchDefsX512[:] {
w.WriteString(fmt.Sprintf(" case %d:\n switch len(out) {\n", in+1))
for out, def := range defs[:] {
w.WriteString(fmt.Sprintf(" case %d:\n", out+1))
w.WriteString(def)
}
w.WriteString("}\n")
}
w.WriteString(`}
panic(fmt.Sprintf("unhandled size: %dx%d", len(in), len(out)))
}
`)

genGF16()
genGF8()
Generate()
Expand Down Expand Up @@ -657,3 +704,227 @@ func genMulAvx2Sixty64(name string, inputs int, outputs int, xor bool) {
Label(name + "_end")
RET()
}

func genMulAvx512GFNI(name string, inputs int, outputs int, xor bool) {
const perLoopBits = 6
const perLoop = 1 << perLoopBits

total := inputs * outputs

doc := []string{
fmt.Sprintf("%s takes %d inputs and produces %d outputs.", name, inputs, outputs),
}
if !xor {
doc = append(doc, "The output is initialized to 0.")
}

// Load shuffle masks on every use.
var loadNone bool
// Use registers for destination registers.
var regDst = true
var reloadLength = false

est := total + outputs + 2

if est > 32 {
loadNone = true
// We run out of GP registers first, now.
if inputs+outputs > 13 {
regDst = false
}
// Save one register by reloading length.
if inputs+outputs > 12 && regDst {
reloadLength = true
}
}

TEXT(name, 0, fmt.Sprintf("func(matrix []uint64, in [][]byte, out [][]byte, start, n int)"))
x := ""
if xor {
x = "Xor"
}
// SWITCH DEFINITION:
//s := fmt.Sprintf("n = (n>>%d)<<%d\n", perLoopBits, perLoopBits)
s := fmt.Sprintf(" mulGFNI_%dx%d_64%s(matrix, in, out, start, n)\n", inputs, outputs, x)
s += fmt.Sprintf("\t\t\t\treturn n\n")
if xor {
switchDefsX512[inputs-1][outputs-1] = s
} else {
switchDefs512[inputs-1][outputs-1] = s
}

if loadNone {
Comment("Loading no tables to registers")
} else {
// loadNone == false
Comment("Loading all tables to registers")
}
if regDst {
Comment("Destination kept in GP registers")
} else {
Comment("Destination kept on stack")
}

Doc(doc...)
Pragma("noescape")
Commentf("Full registers estimated %d YMM used", est)

length := Load(Param("n"), GP64())
matrixBase := GP64()
addr, err := Param("matrix").Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(addr.Addr, matrixBase)
SHRQ(U8(perLoopBits), length)
TESTQ(length, length)
JZ(LabelRef(name + "_end"))

matrix := make([]reg.VecVirtual, total)

for i := range matrix {
if loadNone {
break
}
table := ZMM()
VBROADCASTF32X2(Mem{Base: matrixBase, Disp: i * 8}, table)
matrix[i] = table
}

inPtrs := make([]reg.GPVirtual, inputs)
inSlicePtr := GP64()
addr, err = Param("in").Base().Resolve()
if err != nil {
panic(err)
}
MOVQ(addr.Addr, inSlicePtr)
for i := range inPtrs {
ptr := GP64()
MOVQ(Mem{Base: inSlicePtr, Disp: i * 24}, ptr)
inPtrs[i] = ptr
}
// Destination
dst := make([]reg.VecVirtual, outputs)
dstPtr := make([]reg.GPVirtual, outputs)
addr, err = Param("out").Base().Resolve()
if err != nil {
panic(err)
}
outBase := addr.Addr
outSlicePtr := GP64()
MOVQ(addr.Addr, outSlicePtr)
MOVQ(outBase, outSlicePtr)
for i := range dst {
dst[i] = ZMM()
if !regDst {
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
dstPtr[i] = ptr
}

offset := GP64()
addr, err = Param("start").Resolve()
if err != nil {
panic(err)
}

MOVQ(addr.Addr, offset)
if regDst {
Comment("Add start offset to output")
for _, ptr := range dstPtr {
ADDQ(offset, ptr)
}
}

Comment("Add start offset to input")
for _, ptr := range inPtrs {
ADDQ(offset, ptr)
}
// Offset no longer needed unless not regdst

if reloadLength {
Commentf("Reload length to save a register")
length = Load(Param("n"), GP64())
SHRQ(U8(perLoopBits), length)
}
Label(name + "_loop")

if xor {
Commentf("Load %d outputs", outputs)
for i := range dst {
if regDst {
VMOVDQU64(Mem{Base: dstPtr[i]}, dst[i])
if prefetchDst > 0 {
PREFETCHT0(Mem{Base: dstPtr[i], Disp: prefetchDst})
}
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU64(Mem{Base: ptr, Index: offset, Scale: 1}, dst[i])

if prefetchDst > 0 {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
}

in := ZMM()
look := ZMM()
for i := range inPtrs {
Commentf("Load and process 64 bytes from input %d to %d outputs", i, outputs)
VMOVDQU64(Mem{Base: inPtrs[i]}, in)
if prefetchSrc > 0 {
PREFETCHT0(Mem{Base: inPtrs[i], Disp: prefetchSrc})
}
ADDQ(U8(perLoop), inPtrs[i])

for j := range dst {
if loadNone {
VBROADCASTF32X2(Mem{Base: matrixBase, Disp: 8 * (i*outputs + j)}, look)
if i == 0 && !xor {
VALIGNQ(U8(0), in, look, dst[j])
} else {
VALIGNQ(U8(0), in, look, look)
VXORPD(dst[j], look, dst[j])
}
} else {
if i == 0 && !xor {
VALIGNQ(U8(0), in, matrix[i*outputs+j], dst[j])
} else {
VALIGNQ(U8(0), in, matrix[i*outputs+j], look)
VXORPD(dst[j], look, dst[j])
}
}
}
}
Commentf("Store %d outputs", outputs)
for i := range dst {
if regDst {
VMOVDQU64(dst[i], Mem{Base: dstPtr[i]})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: dstPtr[i], Disp: prefetchDst})
}
ADDQ(U8(perLoop), dstPtr[i])
continue
}
ptr := GP64()
MOVQ(Mem{Base: outSlicePtr, Disp: i * 24}, ptr)
VMOVDQU64(dst[i], Mem{Base: ptr, Index: offset, Scale: 1})
if prefetchDst > 0 && !xor {
PREFETCHT0(Mem{Base: ptr, Disp: prefetchDst, Index: offset, Scale: 1})
}
}
Comment("Prepare for next loop")
if !regDst {
ADDQ(U8(perLoop), offset)
}
DECQ(length)
JNZ(LabelRef(name + "_loop"))
VZEROUPPER()

Label(name + "_end")
RET()
}
18 changes: 18 additions & 0 deletions galois.go
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,24 @@ func genAvx2Matrix(matrixRows [][]byte, inputs, inIdx, outputs int, dst []byte)
return dst
}

var gf2p811dMulMatrices = [256]uint64{0, 0x102040810204080, 0x8001828488102040, 0x8103868c983060c0, 0x408041c2c4881020, 0x418245cad4a850a0, 0xc081c3464c983060, 0xc183c74e5cb870e0, 0x2040a061e2c48810, 0x2142a469f2e4c890, 0xa04122e56ad4a850, 0xa14326ed7af4e8d0, 0x60c0e1a3264c9830, 0x61c2e5ab366cd8b0, 0xe0c16327ae5cb870, 0xe1c3672fbe7cf8f0, 0x102050b071e2c488, 0x112254b861c28408, 0x9021d234f9f2e4c8, 0x9123d63ce9d2a448, 0x50a01172b56ad4a8, 0x51a2157aa54a9428, 0xd0a193f63d7af4e8, 0xd1a397fe2d5ab468, 0x3060f0d193264c98, 0x3162f4d983060c18, 0xb06172551b366cd8, 0xb163765d0b162c58, 0x70e0b11357ae5cb8, 0x71e2b51b478e1c38, 0xf0e13397dfbe7cf8, 0xf1e3379fcf9e3c78, 0x8810a8d83871e2c4, 0x8912acd02851a244, 0x8112a5cb061c284, 0x9132e54a0418204, 0xc890e91afcf9f2e4, 0xc992ed12ecd9b264, 0x48916b9e74e9d2a4, 0x49936f9664c99224, 0xa85008b9dab56ad4, 0xa9520cb1ca952a54, 0x28518a3d52a54a94, 0x29538e3542850a14, 0xe8d0497b1e3d7af4, 0xe9d24d730e1d3a74, 0x68d1cbff962d5ab4, 0x69d3cff7860d1a34, 0x9830f8684993264c, 0x9932fc6059b366cc, 0x18317aecc183060c, 0x19337ee4d1a3468c, 0xd8b0b9aa8d1b366c, 0xd9b2bda29d3b76ec, 0x58b13b2e050b162c, 0x59b33f26152b56ac, 0xb8705809ab57ae5c, 0xb9725c01bb77eedc, 0x3871da8d23478e1c, 0x3973de853367ce9c, 0xf8f019cb6fdfbe7c, 0xf9f21dc37ffffefc, 0x78f19b4fe7cf9e3c, 0x79f39f47f7efdebc, 0xc488d46c1c3871e2, 0xc58ad0640c183162, 0x448956e8942851a2, 0x458b52e084081122, 0x840895aed8b061c2, 0x850a91a6c8902142, 0x409172a50a04182, 0x50b132240800102, 0xe4c8740dfefcf9f2, 0xe5ca7005eedcb972, 0x64c9f68976ecd9b2, 0x65cbf28166cc9932, 0xa44835cf3a74e9d2, 0xa54a31c72a54a952, 0x2449b74bb264c992, 0x254bb343a2448912, 0xd4a884dc6ddab56a, 0xd5aa80d47dfaf5ea, 0x54a90658e5ca952a, 0x55ab0250f5ead5aa, 0x9428c51ea952a54a, 0x952ac116b972e5ca, 0x1429479a2142850a, 0x152b43923162c58a, 0xf4e824bd8f1e3d7a, 0xf5ea20b59f3e7dfa, 0x74e9a639070e1d3a, 0x75eba231172e5dba, 0xb468657f4b962d5a, 0xb56a61775bb66dda, 0x3469e7fbc3860d1a, 0x356be3f3d3a64d9a, 0x4c987cb424499326, 0x4d9a78bc3469d3a6, 0xcc99fe30ac59b366, 0xcd9bfa38bc79f3e6, 0xc183d76e0c18306, 0xd1a397ef0e1c386, 0x8c19bff268d1a346, 0x8d1bbbfa78f1e3c6, 0x6cd8dcd5c68d1b36, 0x6ddad8ddd6ad5bb6, 0xecd95e514e9d3b76, 0xeddb5a595ebd7bf6, 0x2c589d1702050b16, 0x2d5a991f12254b96, 0xac591f938a152b56, 0xad5b1b9b9a356bd6, 0x5cb82c0455ab57ae, 0x5dba280c458b172e, 0xdcb9ae80ddbb77ee, 0xddbbaa88cd9b376e, 0x1c386dc69123478e, 0x1d3a69ce8103070e, 0x9c39ef42193367ce, 0x9d3beb4a0913274e, 0x7cf88c65b76fdfbe, 0x7dfa886da74f9f3e, 0xfcf90ee13f7ffffe, 0xfdfb0ae92f5fbf7e, 0x3c78cda773e7cf9e, 0x3d7ac9af63c78f1e, 0xbc794f23fbf7efde, 0xbd7b4b2bebd7af5e, 0xe2c46a368e1c3871, 0xe3c66e3e9e3c78f1, 0x62c5e8b2060c1831, 0x63c7ecba162c58b1, 0xa2442bf44a942851, 0xa3462ffc5ab468d1, 0x2245a970c2840811, 0x2347ad78d2a44891, 0xc284ca576cd8b061, 0xc386ce5f7cf8f0e1, 0x428548d3e4c89021, 0x43874cdbf4e8d0a1, 0x82048b95a850a041, 0x83068f9db870e0c1, 0x205091120408001, 0x3070d193060c081, 0xf2e43a86fffefcf9, 0xf3e63e8eefdebc79, 0x72e5b80277eedcb9, 0x73e7bc0a67ce9c39, 0xb2647b443b76ecd9, 0xb3667f4c2b56ac59, 0x3265f9c0b366cc99, 0x3367fdc8a3468c19, 0xd2a49ae71d3a74e9, 0xd3a69eef0d1a3469, 0x52a51863952a54a9, 0x53a71c6b850a1429, 0x9224db25d9b264c9, 0x9326df2dc9922449, 0x122559a151a24489, 0x13275da941820409, 0x6ad4c2eeb66ddab5, 0x6bd6c6e6a64d9a35, 0xead5406a3e7dfaf5, 0xebd744622e5dba75, 0x2a54832c72e5ca95, 0x2b56872462c58a15, 0xaa5501a8faf5ead5, 0xab5705a0ead5aa55, 0x4a94628f54a952a5, 0x4b96668744891225, 0xca95e00bdcb972e5, 0xcb97e403cc993265, 0xa14234d90214285, 0xb16274580010205, 0x8a15a1c9183162c5, 0x8b17a5c108112245, 0x7af4925ec78f1e3d, 0x7bf69656d7af5ebd, 0xfaf510da4f9f3e7d, 0xfbf714d25fbf7efd, 0x3a74d39c03070e1d, 0x3b76d79413274e9d, 0xba7551188b172e5d, 0xbb7755109b376edd, 0x5ab4323f254b962d, 0x5bb63637356bd6ad, 0xdab5b0bbad5bb66d, 0xdbb7b4b3bd7bf6ed, 0x1a3473fde1c3860d, 0x1b3677f5f1e3c68d, 0x9a35f17969d3a64d, 0x9b37f57179f3e6cd, 0x264cbe5a92244993, 0x274eba5282040913, 0xa64d3cde1a3469d3, 0xa74f38d60a142953, 0x66ccff9856ac59b3, 0x67cefb90468c1933, 0xe6cd7d1cdebc79f3, 0xe7cf7914ce9c3973, 0x60c1e3b70e0c183, 0x70e1a3360c08103, 0x860d9cbff8f0e1c3, 0x870f98b7e8d0a143, 0x468c5ff9b468d1a3, 0x478e5bf1a4489123, 0xc68ddd7d3c78f1e3, 0xc78fd9752c58b163, 0x366ceeeae3c68d1b, 0x376eeae2f3e6cd9b, 0xb66d6c6e6bd6ad5b, 0xb76f68667bf6eddb, 0x76ecaf28274e9d3b, 0x77eeab20376eddbb, 0xf6ed2dacaf5ebd7b, 0xf7ef29a4bf7efdfb, 0x162c4e8b0102050b, 0x172e4a831122458b, 0x962dcc0f8912254b, 0x972fc807993265cb, 0x56ac0f49c58a152b, 0x57ae0b41d5aa55ab, 0xd6ad8dcd4d9a356b, 0xd7af89c55dba75eb, 0xae5c1682aa55ab57, 0xaf5e128aba75ebd7, 0x2e5d940622458b17, 0x2f5f900e3265cb97, 0xeedc57406eddbb77, 0xefde53487efdfbf7, 0x6eddd5c4e6cd9b37, 0x6fdfd1ccf6eddbb7, 0x8e1cb6e348912347, 0x8f1eb2eb58b163c7, 0xe1d3467c0810307, 0xf1f306fd0a14387, 0xce9cf7218c193367, 0xcf9ef3299c3973e7, 0x4e9d75a504091327, 0x4f9f71ad142953a7, 0xbe7c4632dbb76fdf, 0xbf7e423acb972f5f, 0x3e7dc4b653a74f9f, 0x3f7fc0be43870f1f, 0xfefc07f01f3f7fff, 0xfffe03f80f1f3f7f, 0x7efd8574972f5fbf, 0x7fff817c870f1f3f, 0x9e3ce6533973e7cf, 0x9f3ee25b2953a74f, 0x1e3d64d7b163c78f, 0x1f3f60dfa143870f, 0xdebca791fdfbf7ef, 0xdfbea399eddbb76f, 0x5ebd251575ebd7af, 0x5fbf211d65cb972f}

func genGFNIMatrix(matrixRows [][]byte, inputs, inIdx, outputs int, dst []uint64) []uint64 {
if !avx2CodeGen {
panic("codegen not enabled")
}
total := inputs * outputs

// Duplicated in+out
dst = dst[:total]
for i, row := range matrixRows[:outputs] {
for j, idx := range row[inIdx : inIdx+inputs] {
dst[j*outputs+i] = gf2p811dMulMatrices[idx]
}
}
return dst
}

// xor slices writing to out.
func sliceXorGo(in, out []byte, _ *options) {
for len(out) >= 32 {
Expand Down
Loading

0 comments on commit bf0ce8a

Please sign in to comment.