Skip to content

Commit

Permalink
performance and small api changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Franzen committed Mar 20, 2017
1 parent 826d84d commit 9bbbd0b
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 57 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func main() {

var zero hnsw.Point = make([]float32, 128)

h := hnsw.New(M, efConstruction, &zero)
h := hnsw.New(M, efConstruction, zero)
h.Grow(10000)

// Note that added IDs must start from 1
Expand All @@ -50,12 +50,12 @@ func main() {
fmt.Printf("%v queries / second (single thread)\n", 1000.0/stop.Seconds())
}

func randomPoint() *hnsw.Point {
func randomPoint() hnsw.Point {
var v hnsw.Point = make([]float32, 128)
for i := range v {
v[i] = rand.Float32()
}
return &v
return v
}

```
10 changes: 5 additions & 5 deletions examples/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"math/rand"
"time"

".."
hnsw "github.com/Bithack/go-hnsw"
)

func main() {
Expand All @@ -19,7 +19,7 @@ func main() {

var zero hnsw.Point = make([]float32, 128)

h := hnsw.New(M, efConstruction, &zero)
h := hnsw.New(M, efConstruction, zero)
h.Grow(10000)

for i := 1; i <= 10000; i++ {
Expand All @@ -30,7 +30,7 @@ func main() {
}

fmt.Printf("Generating queries and calculating true answers using bruteforce search...\n")
queries := make([]*hnsw.Point, 1000)
queries := make([]hnsw.Point, 1000)
truth := make([][]uint32, 1000)
for i := range queries {
queries[i] = randomPoint()
Expand Down Expand Up @@ -63,10 +63,10 @@ func main() {

}

func randomPoint() *hnsw.Point {
func randomPoint() hnsw.Point {
var v hnsw.Point = make([]float32, 128)
for i := range v {
v[i] = rand.Float32()
}
return &v
return v
}
2 changes: 2 additions & 0 deletions f32/f32_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
package f32

func L2Squared(x, y []float32) float32

func L2Squared8AVX(x, y []float32) float32
64 changes: 64 additions & 0 deletions f32/l2squared8_avx_amd64.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//+build !noasm,!appengine

#include "textflag.h"

// This version is AVX optimized for vectors where the dimension is a multiple of 8
// Latest GO versions seems to align []float32 slices on 32-bytes on a 64-bit system, so we skip checks for this...

// func L2Squared8AVX(x, y []float32) (sum float32)
TEXT ·L2Squared8AVX(SB), NOSPLIT, $0
MOVQ x_base+0(FP), SI // SI = &x
MOVQ x_len+8(FP), AX // AX = len(x)
MOVQ y_base+24(FP), DI // DI = &y

MOVQ AX, BX // BX = len(x)

SHLQ $2, AX
ADDQ AX, SI
ADDQ AX, DI
SHRQ $2, AX
NEGQ AX

BYTE $0xc5; BYTE $0xfc; BYTE $0x57; BYTE $0xc0 // vxorps ymm0,ymm0,ymm0

ANDQ $0xF, BX // BX = len % 16
JZ l2_loop_16

// PRE LOOP, 8 values
BYTE $0xc5; BYTE $0xfc; BYTE $0x28; BYTE $0x0c; BYTE $0x86 //vmovaps ymm1,YMMWORD PTR [esi+eax*4]
BYTE $0xc5; BYTE $0xf4; BYTE $0x5c; BYTE $0x0c; BYTE $0x87 // vsubps ymm1,ymm1,YMMWORD PTR [edi+eax*4]
BYTE $0xc5; BYTE $0xf4; BYTE $0x59; BYTE $0xc9
BYTE $0xc5; BYTE $0xfc; BYTE $0x58; BYTE $0xc1;
ADDQ $8, AX

l2_loop_16:
BYTE $0xc5; BYTE $0xfc; BYTE $0x28; BYTE $0x0c; BYTE $0x86 //vmovaps ymm1,YMMWORD PTR [esi+eax*4]
BYTE $0xc5; BYTE $0xfc; BYTE $0x28; BYTE $0x54; BYTE $0x86; BYTE $0x20 //vmovaps ymm2,YMMWORD PTR [esi+eax*4+0x20]
BYTE $0xc5; BYTE $0xf4; BYTE $0x5c; BYTE $0x0c; BYTE $0x87 // vsubps ymm1,ymm1,YMMWORD PTR [edi+eax*4]
BYTE $0xc5; BYTE $0xec; BYTE $0x5c; BYTE $0x54; BYTE $0x87; BYTE $0x20 // vsubps ymm2,ymm2,YMMWORD PTR [edi+eax*4+0x20]
BYTE $0xc5; BYTE $0xf4; BYTE $0x59; BYTE $0xc9 // vmulps ymmX,ymmX,ymmX
BYTE $0xc5; BYTE $0xec; BYTE $0x59; BYTE $0xd2
BYTE $0xc5; BYTE $0xfc; BYTE $0x58; BYTE $0xc1; // vaddps ymm0,ymm0,ymmX
BYTE $0xc5; BYTE $0xfc; BYTE $0x58; BYTE $0xc2;
ADDQ $16, AX // eax += 16
JS l2_loop_16 // jump if negative

l2_end:
//auto x = _mm256_permute2f128_ps(v, v, 1);
BYTE $0xc4; BYTE $0xe3; BYTE $0x7d; BYTE $0x06; BYTE $0xc8; BYTE $0x01; // vperm2f128 ymm1,ymm0,ymm0,0x1
//auto y = _mm256_add_ps(v, x);
BYTE $0xc5;BYTE $0xfc; BYTE $0x58;BYTE $0xc1; // vaddps ymm0,ymm0,ymm1
//x = _mm256_shuffle_ps(y, y, _MM_SHUFFLE(2, 3, 0, 1)=0xB1);
//_MM_SHUFFLE
BYTE $0xc5;BYTE $0xfc;BYTE $0xc6;BYTE $0xc8; BYTE $0xb1 // vshufps ymm1,ymm0,ymm0,0xb1
//x = _mm256_add_ps(x, y);
BYTE $0xc5;BYTE $0xf4; BYTE $0x58;BYTE $0xc8 // vaddps ymm1,ymm1,ymm0
//y = _mm256_shuffle_ps(x, x, _MM_SHUFFLE(1, 0, 3, 2)=0x8E);
BYTE $0xc5;BYTE $0xf4; BYTE $0xc6;BYTE $0xc1; BYTE $0x8e // vshufps ymm0,ymm1,ymm1,0x8e
//return _mm256_add_ps(x, y);
BYTE $0xc5; BYTE $0xf4; BYTE $0x58; BYTE $0xc8 // vaddps ymm1,ymm1,ymm0

VZEROUPPER
MOVSS X1, ret+48(FP) // Return final sum.

RET
1 change: 0 additions & 1 deletion f32/l2squared_amd64.s
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

// This is the 16-byte SSE2 version.
// It skips pointer alignment checks, since recent Go versions seem to align all []float32 slices on 16-byte boundaries
// TODO write the 32-byte AVX version!

// func L2Squared(x, y []float32) (sum float32)
TEXT ·L2Squared(SB), NOSPLIT, $0
Expand Down
17 changes: 17 additions & 0 deletions f32/l2squared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func TestAlignment(t *testing.T) {
}
}

func Test24(t *testing.T) {
a := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}
b := []float32{4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
assert.Equal(t, DistGo(b, a), L2Squared8AVX(a, b), "8avx Incorrect")
}

func Test128(t *testing.T) {
a := []float32{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4,
Expand All @@ -63,6 +70,7 @@ func Test128(t *testing.T) {
4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
}
assert.Equal(t, DistGo(a, b), L2Squared(a, b), "Incorrect")
assert.Equal(t, DistGo(b, a), L2Squared8AVX(a, b), "8avx Incorrect")
}

func TestBenchmark(t *testing.T) {
Expand All @@ -77,13 +85,22 @@ func TestBenchmark(t *testing.T) {
4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4, 4, 3, 1, 4,
}
l := 10000000
fmt.Printf("Testing %v calls with %v dim []float32\n", l, len(a2))

start := time.Now()
for i := 0; i < l; i++ {
L2Squared(a2, b2)
}
stop := time.Since(start)
fmt.Printf("l2squared Done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds())

start = time.Now()
for i := 0; i < l; i++ {
L2Squared8AVX(a2, b2)
}
stop = time.Since(start)
fmt.Printf("l2squared8AVX Done in %v. %v calcs / second\n", stop, float64(l)/stop.Seconds())

start = time.Now()
for i := 0; i < l; i++ {
DistGo(a2, b2)
Expand Down
Loading

0 comments on commit 9bbbd0b

Please sign in to comment.