Skip to content

Commit

Permalink
Make ARM SVE code vector length agnostic (#285)
Browse files Browse the repository at this point in the history
* Make ARM SVE code vector length agnostic

* Generate correct matrix for code-gen based on actual vector length (for 256 bits and below)

* Missing changes in reedsolomon.go

* Fix build for testing on amd64
  • Loading branch information
fwessels authored Aug 23, 2024
1 parent 3412d52 commit 67157af
Show file tree
Hide file tree
Showing 12 changed files with 499 additions and 254 deletions.
141 changes: 141 additions & 0 deletions _gen/gen-arm-sve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log"
"os"
"regexp"
"slices"
"strconv"
"strings"

Expand Down Expand Up @@ -359,3 +360,143 @@ func genArmSve() {
fromAvx2ToSve()
addEarlyExit("Sve")
}

func assemble(sve string) string {
opcode, err := sve_as.Assemble(sve)
if err != nil {
return fmt.Sprintf(" WORD $0x00000000 // %s", sve)
} else {
return fmt.Sprintf(" WORD $0x%08x // %s", opcode, sve)
}
}

func addArmSveVectorLength() (addInits []string) {
const filename = "../galois_gen_arm64.s"
asmOut := &bytes.Buffer{}

file, err := os.Open(filename)
if err != nil {
return
}
defer file.Close()

// Create a scanner to read the file line by line
scanner := bufio.NewScanner(file)

routine := ""
addInits = make([]string, 0)

// Iterate over each line
for scanner.Scan() {
line := scanner.Text()

if strings.HasPrefix(line, "TEXT ·") {
routine = line
}

correctShift := func(shift, vl string) {
if strings.Contains(line, " // lsr ") && strings.HasSuffix(strings.TrimSpace(line), ", "+shift) {
instr := strings.Split(strings.TrimSpace(line), "// lsr ")[1]
args := strings.Split(instr, ", ")
if len(args) == 3 && args[0] == args[1] {
// keep the original right shift, but reverse the effect (so effectively
// clearing out the lower bits so we cannot do eg. "half loops" )
line += "\n"
line += assemble(fmt.Sprintf("lsl %s, %s, %s", args[0], args[1], shift)) + "\n"
line += assemble(fmt.Sprintf("rdvl x16, %s", vl)) + "\n"
line += assemble(fmt.Sprintf("udiv %s, %s, x16", args[0], args[1]))
}
}
}

correctShift("#6", "#2")
correctShift("#5", "#1")

if strings.Contains(line, " // add ") && strings.HasSuffix(strings.TrimSpace(line), "#64") {
instr := strings.Split(strings.TrimSpace(line), "// add ")[1]
args := strings.Split(instr, ", ")
if len(args) == 3 && args[0] == args[1] {
line = assemble(fmt.Sprintf("addvl %s, %s, #2", args[0], args[1]))
}
}

if strings.Contains(line, " // add ") && strings.HasSuffix(strings.TrimSpace(line), "#32") {
instr := strings.Split(strings.TrimSpace(line), "// add ")[1]
args := strings.Split(instr, ", ")
if len(args) == 3 && args[0] == args[1] {
line = assemble(fmt.Sprintf("addvl %s, %s, #1", args[0], args[1]))
}
}

if strings.Contains(line, " // add ") && strings.HasSuffix(strings.TrimSpace(line), "#4") {
// mark routine as needing initialization of register 17
addInits = append(addInits, routine)
line = assemble("add x15, x15, x17")
}

asmOut.WriteString(line + "\n")
}

// Check for any errors that occurred during scanning
if err = scanner.Err(); err != nil {
log.Fatal(err)
} else if err = os.WriteFile("../galois_gen_arm64.s", asmOut.Bytes(), 0644); err != nil {
log.Fatal(err)
}

return
}

func addArmSveInitializations(addInits []string) {

const filename = "../galois_gen_arm64.s"
asmOut := &bytes.Buffer{}

file, err := os.Open(filename)
if err != nil {
return
}
defer file.Close()

// Create a scanner to read the file line by line
scanner := bufio.NewScanner(file)
routine := ""
checkNextLine := false

// Iterate over each line
for scanner.Scan() {
line := scanner.Text()

if strings.HasPrefix(line, "TEXT ·") {
routine = line
}

if strings.Contains(line, "// Load number of input shards") {
checkNextLine = true
} else {
if checkNextLine {
idx := slices.IndexFunc(addInits, func(s string) bool { return s == routine })
if idx != -1 {
line += "\n"
line += assemble("rdvl x17, #1") + "\n"
line += assemble("lsr x17, x17, #3")
}
checkNextLine = false
}
}

asmOut.WriteString(line + "\n")
}

// Check for any errors that occurred during scanning
if err = scanner.Err(); err != nil {
log.Fatal(err)
} else if err = os.WriteFile("../galois_gen_arm64.s", asmOut.Bytes(), 0644); err != nil {
log.Fatal(err)
}
}

func genArmSveAllVl() {
addInits := addArmSveVectorLength()
addArmSveInitializations(addInits)
}
1 change: 1 addition & 0 deletions _gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func main() {
if pshufb {
genArmSve()
genArmNeon()
genArmSveAllVl()
}
Generate()
}
Expand Down
2 changes: 1 addition & 1 deletion _gen/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (

require (
github.com/fwessels/avxTwo2sve v0.0.0-20240611172111-6b8528700471 // indirect
github.com/fwessels/sve-as v0.0.0-20240611015707-daffc010447f // indirect
github.com/fwessels/sve-as v0.0.0-20240817192210-83d5dbff9505 // indirect
golang.org/x/mod v0.6.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/tools v0.2.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions _gen/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/fwessels/avxTwo2sve v0.0.0-20240611172111-6b8528700471 h1:omdgAKxePZx
github.com/fwessels/avxTwo2sve v0.0.0-20240611172111-6b8528700471/go.mod h1:9+ibRsEIs0vLXkalKCGEbZfVS4fafeIvMvM9GvIsdeQ=
github.com/fwessels/sve-as v0.0.0-20240611015707-daffc010447f h1:HQud3yIU82LdkQzHEYiSJs73wCHjprIqeZE9JvSjKbQ=
github.com/fwessels/sve-as v0.0.0-20240611015707-daffc010447f/go.mod h1:j3s7EY79XxNMyjx/54Vo6asZafWU4yijB+KIfj4hrh8=
github.com/fwessels/sve-as v0.0.0-20240817192210-83d5dbff9505 h1:oKLoVXrXDsNNTdNLsSbEu18Vy0Z0b1yeanl5TG4qSyU=
github.com/fwessels/sve-as v0.0.0-20240817192210-83d5dbff9505/go.mod h1:j3s7EY79XxNMyjx/54Vo6asZafWU4yijB+KIfj4hrh8=
github.com/klauspost/asmfmt v1.3.1 h1:7xZi1N7s9gTLbqiM8KUv8TLyysavbTRGBT5/ly0bRtw=
github.com/klauspost/asmfmt v1.3.1/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/mmcloughlin/avo v0.5.1-0.20221128045730-bf1d05562091 h1:C2c8ttOBeyhs1SvyCXVPCFd0EqtPiTKGnMWQ+JkM0Lc=
Expand Down
17 changes: 9 additions & 8 deletions galois.go
Original file line number Diff line number Diff line change
Expand Up @@ -910,30 +910,31 @@ func galExp(a byte, n int) byte {
return expTable[uint8(logResult)]
}

func genCodeGenMatrix(matrixRows [][]byte, inputs, inIdx, outputs int, dst []byte) []byte {
func genCodeGenMatrix(matrixRows [][]byte, inputs, inIdx, outputs, vectorLength int, dst []byte) []byte {
if !codeGen {
panic("codegen not enabled")
}
total := inputs * outputs

// Duplicated in+out
wantBytes := total * 32 * 2
wantBytes := total * vectorLength * 2
if cap(dst) < wantBytes {
dst = AllocAligned(1, wantBytes)[0]
} else {
dst = dst[:wantBytes]
}
for i, row := range matrixRows[:outputs] {
for j, idx := range row[inIdx : inIdx+inputs] {
dstIdx := (j*outputs + i) * 64
dstIdx := (j*outputs + i) * vectorLength * 2
dstPart := dst[dstIdx:]
dstPart = dstPart[:64]
dstPart = dstPart[:vectorLength*2]
lo := mulTableLow[idx][:]
hi := mulTableHigh[idx][:]
copy(dstPart[:16], lo)
copy(dstPart[16:32], lo)
copy(dstPart[32:48], hi)
copy(dstPart[48:64], hi)

for k := 0; k < vectorLength; k += 16 {
copy(dstPart[k:k+16], lo)
copy(dstPart[vectorLength*2-(k+16):vectorLength*2-k], hi)
}
}
}
return dst
Expand Down
2 changes: 1 addition & 1 deletion galois_amd64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ import (

func TestGenGalois(t *testing.T) {
if defaultOptions.useAVX2 {
testGenGaloisUpto10x10(t, galMulSlicesAvx2, galMulSlicesAvx2Xor)
testGenGaloisUpto10x10(t, galMulSlicesAvx2, galMulSlicesAvx2Xor, 32)
}
}
8 changes: 6 additions & 2 deletions galois_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ func getVectorLength() (vl, pl uint64)

func init() {
if defaultOptions.useSVE {
if vl, _ := getVectorLength(); vl != 256 {
defaultOptions.useSVE = false // Temp fix: disable SVE for non-256 vector widths (ie Graviton4)
if vl, _ := getVectorLength(); vl <= 256 {
// set vector length in bytes
defaultOptions.vectorLength = int(vl) >> 3
} else {
// disable SVE for hardware implementatons over 256 bits (only know to be Fujitsu A64FX atm)
defaultOptions.useSVE = false
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions galois_arm64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (

func TestGenGalois(t *testing.T) {
if defaultOptions.useSVE {
testGenGaloisUpto10x10(t, galMulSlicesSve, galMulSlicesSveXor)
testGenGaloisUpto10x10(t, galMulSlicesSve, galMulSlicesSveXor, defaultOptions.vectorLength)
}
if defaultOptions.useNEON {
testGenGaloisUpto10x10(t, galMulSlicesNeon, galMulSlicesNeonXor)
testGenGaloisUpto10x10(t, galMulSlicesNeon, galMulSlicesNeonXor, 32)
}
}
Loading

0 comments on commit 67157af

Please sign in to comment.