Skip to content

Commit

Permalink
added save/load to disc of index
Browse files Browse the repository at this point in the history
  • Loading branch information
Oscar Franzén committed Jun 29, 2017
1 parent 9bbbd0b commit 3b020fb
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 11 deletions.
143 changes: 143 additions & 0 deletions hnsw.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package hnsw

import (
"compress/gzip"
"encoding/binary"
"fmt"
"io"
"math"
"math/rand"
"os"
"sync"

"github.com/Bithack/go-hnsw/bitsetpool"
Expand Down Expand Up @@ -44,6 +48,145 @@ type Hnsw struct {
enterpoint uint32
}

func Load(filename string) (*Hnsw, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
z, err := gzip.NewReader(f)
if err != nil {
return nil, err
}

h := new(Hnsw)
h.M = readInt32(z)
h.M0 = readInt32(z)
h.efConstruction = readInt32(z)
h.linkMode = readInt32(z)
h.DelaunayType = readInt32(z)
h.LevelMult = readFloat64(z)
h.maxLayer = readInt32(z)
h.enterpoint = uint32(readInt32(z))

h.DistFunc = f32.L2Squared8AVX
h.bitset = bitsetpool.New()

l := readInt32(z)
h.nodes = make([]node, l)

for i := range h.nodes {

l := readInt32(z)
h.nodes[i].p = make([]float32, l)

err = binary.Read(z, binary.LittleEndian, h.nodes[i].p)
if err != nil {
panic(err)
}
h.nodes[i].level = readInt32(z)

l = readInt32(z)
h.nodes[i].friends = make([][]uint32, l)

for j := range h.nodes[i].friends {
l := readInt32(z)
h.nodes[i].friends[j] = make([]uint32, l)
err = binary.Read(z, binary.LittleEndian, h.nodes[i].friends[j])
if err != nil {
panic(err)
}
}

}

z.Close()
f.Close()

return h, nil
}

// Save writes to current index to a gzipped binary data file
func (h *Hnsw) Save(filename string) error {
f, err := os.Create(filename)
if err != nil {
return err
}
z := gzip.NewWriter(f)

writeInt32(h.M, z)
writeInt32(h.M0, z)
writeInt32(h.efConstruction, z)
writeInt32(h.linkMode, z)
writeInt32(h.DelaunayType, z)
writeFloat64(h.LevelMult, z)
writeInt32(h.maxLayer, z)
writeInt32(int(h.enterpoint), z)

l := len(h.nodes)
writeInt32(l, z)

if err != nil {
return err
}
for _, n := range h.nodes {
l := len(n.p)
writeInt32(l, z)
err = binary.Write(z, binary.LittleEndian, []float32(n.p))
if err != nil {
panic(err)
}
writeInt32(n.level, z)

l = len(n.friends)
writeInt32(l, z)
for _, f := range n.friends {
l := len(f)
writeInt32(l, z)
err = binary.Write(z, binary.LittleEndian, f)
if err != nil {
panic(err)
}
}
}

z.Close()
f.Close()

return nil
}

func writeInt32(v int, w io.Writer) {
i := int32(v)
err := binary.Write(w, binary.LittleEndian, &i)
if err != nil {
panic(err)
}
}

func readInt32(r io.Reader) int {
var i int32
err := binary.Read(r, binary.LittleEndian, &i)
if err != nil {
panic(err)
}
return int(i)
}

func writeFloat64(v float64, w io.Writer) {
err := binary.Write(w, binary.LittleEndian, &v)
if err != nil {
panic(err)
}
}

func readFloat64(r io.Reader) (v float64) {
err := binary.Read(r, binary.LittleEndian, &v)
if err != nil {
panic(err)
}
return
}

func (h *Hnsw) getFriends(n uint32, level int) []uint32 {
if len(h.nodes[n].friends) < level+1 {
return make([]uint32, 0)
Expand Down
48 changes: 37 additions & 11 deletions hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,44 @@ import (
"sync/atomic"
"testing"
"time"
)

func TestSIFT(t *testing.T) {

efSearch := []int{1, 2, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 300, 400}
"github.com/stretchr/testify/assert"
)

//prefix := "siftsmall/siftsmall"
prefix := "sift/sift"
//dataSize := 10000
dataSize := 1000000
var prefix = "siftsmall/siftsmall"
var dataSize = 10000
var efSearch = []int{1, 2, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 300, 400}
var queries []Point
var truth [][]uint32

func TestMain(m *testing.M) {
// LOAD QUERIES AND GROUNDTRUTH
fmt.Printf("Loading query records\n")
queries, truth := loadQueriesFromFvec(prefix)
queries, truth = loadQueriesFromFvec(prefix)
os.Exit(m.Run())
}
func TestSaveLoad(t *testing.T) {
h := buildIndex()
testSearch(h)

fmt.Printf("Saving to index.dat\n")
err := h.Save("index.dat")
assert.Nil(t, err)

fmt.Printf("Loading from index.dat\n")
h2, err := Load("index.dat")
assert.Nil(t, err)

fmt.Printf(h2.Stats())
testSearch(h2)
}

func TestSIFT(t *testing.T) {
h := buildIndex()
testSearch(h)
}

func buildIndex() *Hnsw {
// BUILD INDEX
var p Point
p = make([]float32, 128)
Expand All @@ -39,9 +62,12 @@ func TestSIFT(t *testing.T) {
buildFromChan(h, points)
buildStop := time.Since(buildStart)
fmt.Printf("Index build in %v\n", buildStop)

fmt.Printf(h.Stats())

return h
}

func testSearch(h *Hnsw) {
// SEARCH
for _, ef := range efSearch {
fmt.Printf("Now searching with ef=%v\n", ef)
Expand Down Expand Up @@ -204,7 +230,7 @@ func loadDataFromFvec(prefix string, points chan job) {
}
points <- job{p: vec, id: uint32(count)}
count++
if count%10000 == 0 {
if count%1000 == 0 {
fmt.Printf("Read %v records\n", count)
}
}
Expand Down

0 comments on commit 3b020fb

Please sign in to comment.