Skip to content

Commit

Permalink
fixed l2squared for different lengths + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Franzen committed Mar 16, 2017
1 parent 4b5ebe9 commit c8df852
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 28 deletions.
42 changes: 14 additions & 28 deletions f32/l2squared_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

#include "textflag.h"

// This is the 16-byte SSE2 version.
// It skips pointer alignment checks since, according to the test, Go seems to align all []float32 slices on 32-byte boundaries.
// TODO: write the 32-byte AVX version!

// func L2Squared(x, y []float32) (sum float32)
TEXT ·L2Squared(SB), NOSPLIT, $0
MOVQ x_base+0(FP), SI // SI = &x
Expand All @@ -16,27 +20,9 @@ TEXT ·L2Squared(SB), NOSPLIT, $0
MOVSD $(0.0), X1 // sum = 0

XORQ AX, AX // i = 0
PXOR X2, X2 // 2 NOP instructions (PXOR) to align
PXOR X3, X3 // loop to cache line
MOVQ DI, CX
ANDQ $0xF, CX // Align on 16-byte boundary for ADDPS
JZ l2_no_trim // if CX == 0 { goto l2_no_trim }

XORQ $0xF, CX // CX = 4 - floor( BX % 16 / 4 )
INCQ CX
SHRQ $2, CX
//PXOR X2, X2 // 2 NOP instructions (PXOR) to align
//PXOR X3, X3 // loop to cache line

l2_align: // Trim first value(s) in unaligned buffer do {
MOVSS (SI)(AX*4), X2 // X2 = x[i]
MULSS X0, X2 // X2 *= a
ADDSS (DI)(AX*4), X2 // X2 += y[i]
MOVSS X2, (DI)(AX*4) // y[i] = X2
INCQ AX // i++
DECQ BX
JZ l2_end // if --BX == 0 { return }
LOOP l2_align // } while --CX > 0

l2_no_trim:
MOVQ BX, CX
ANDQ $0xF, BX // BX = len % 16
SHRQ $4, CX // CX = int( len / 16 )
Expand Down Expand Up @@ -74,10 +60,10 @@ l2_tail4_start: // Reset loop counter for 4-wide tail loop
JZ l2_tail_start // if CX == 0 { goto l2_tail_start }

l2_tail4: // Loop unrolled 4x do {
MOVUPS (SI)(AX*4), X2 // X2 = x[i]
MULPS X0, X2 // X2 *= a
ADDPS (DI)(AX*4), X2 // X2 += y[i]
MOVUPS X2, (DI)(AX*4) // y[i] = X2
MOVUPS (SI)(AX*4), X2 // X2 = x[i]
SUBPS (DI)(AX*4), X2 // X2 -= y[i:i+4]
MULPS X2, X2 // X2 *= X2
ADDPS X2, X1 // X1 += X2
ADDQ $4, AX // i += 4
LOOP l2_tail4 // } while --CX > 0

Expand All @@ -87,10 +73,10 @@ l2_tail_start: // Reset loop counter for 1-wide tail loop
JZ l2_end // if CX == 0 { return }

l2_tail:
MOVSS (SI)(AX*4), X1 // X1 = x[i]
MULSS X0, X1 // X1 *= a
ADDSS (DI)(AX*4), X1 // X1 += y[i]
MOVSS X1, (DI)(AX*4) // y[i] = X1
MOVSS (SI)(AX*4), X2 // X2 = x[i]
SUBSS (DI)(AX*4), X2 // X2 -= y[i]
MULSS X2, X2 // X2 *= X2 (square the difference)
ADDSS X2, X1 // sum += X2
INCQ AX // i++
LOOP l2_tail // } while --CX > 0

Expand Down
93 changes: 93 additions & 0 deletions f32/l2squared_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package f32

import (
"fmt"
"math/rand"
"testing"
"time"
"unsafe"

"github.com/stretchr/testify/assert"
)

// DistGo is the pure-Go reference implementation: it returns the sum of the
// squared element-wise differences between a and b.
//
// Iteration is driven by len(a); b must hold at least len(a) elements,
// otherwise indexing b panics. Extra trailing elements of b are ignored.
func DistGo(a, b []float32) (r float32) {
	for i, av := range a {
		diff := av - b[i]
		r += diff * diff
	}
	return r
}

func Test1(t *testing.T) {
a := []float32{1}
b := []float32{4}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
}

func Test4(t *testing.T) {
a := []float32{1, 2, 3, 4}
b := []float32{4, 3, 2, 1}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
}

func Test5(t *testing.T) {
a := []float32{1, 2, 3, 4, 1}
b := []float32{4, 3, 2, 1, 9}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
}

func Test21(t *testing.T) {
a := []float32{1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 1, 2, 3, 4, 1, 9}
b := []float32{4, 3, 2, 1, 9, 4, 3, 2, 1, 9, 4, 3, 2, 1, 9, 4, 3, 2, 1, 9, 0}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
}

// TestAlignment checks whether the element data of freshly allocated
// []float32 slices of random length (1..255) starts on a 16-byte and on a
// 32-byte boundary. The header comment in l2squared_amd64.s cites this test
// as the reason the assembly skips pointer alignment checks.
//
// The address inspected is &a[0], the first element of the backing array.
// &a is the address of the local slice header variable and says nothing
// about where the float32 data lives.
//
// NOTE(review): the Go allocator is not documented to guarantee 16- or
// 32-byte alignment for small allocations, so this test may fail now that it
// looks at the real data pointer. If it does, the assembly's alignment
// assumption needs revisiting rather than this test.
func TestAlignment(t *testing.T) {
	for i := 0; i < 10000; i++ {
		a := make([]float32, rand.Intn(256))
		if len(a) == 0 {
			// rand.Intn(256) can return 0; an empty slice has no first
			// element whose address could be taken.
			continue
		}
		addr := uintptr(unsafe.Pointer(&a[0]))
		// Stop at the first failure so a systematic misalignment does not
		// produce thousands of identical messages.
		if !assert.True(t, addr%16 == 0, "[]float32 Not 16-bytes aligned!") {
			return
		}
		if !assert.True(t, addr%32 == 0, "[]float32 Not 32-bytes aligned!") {
			return
		}
	}
}

func Test128(t *testing.T) {
a := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
}
b := []float32{4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
}

// TestBenchmark times L2Squared against the pure-Go DistGo on two 128-element
// vectors and prints the elapsed time and calls per second of each to stdout.
// It makes no assertions.
//
// Each timing loop performs 10,000,000 calls, so the test is skipped when
// `go test -short` is used.
func TestBenchmark(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping timing comparison in -short mode")
	}

	a2 := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
		1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
		1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
		1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
	}
	b2 := []float32{4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
		4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
		4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
		4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
	}

	// Number of calls made in each timing loop.
	l := 10000000

	start := time.Now()
	for i := 0; i < l; i++ {
		L2Squared(a2, b2)
	}
	stop := time.Since(start)
	fmt.Printf("l2squared Done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds())

	start = time.Now()
	for i := 0; i < l; i++ {
		DistGo(a2, b2)
	}
	stop = time.Since(start)
	fmt.Printf("Go version done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds())
}

0 comments on commit c8df852

Please sign in to comment.