-
Notifications
You must be signed in to change notification settings - Fork 12
/
lshensemble.go
178 lines (162 loc) · 5.8 KB
/
lshensemble.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package lshensemble
import (
"errors"
"fmt"
"sync"
"time"
cmap "github.com/orcaman/concurrent-map"
)
// param holds the pair of LSH parameters (k, l) chosen for one partition.
type param struct {
	k, l int
}
// Partition describes one domain-size partition of the LSH Ensemble index.
// Prepare treats the range as inclusive on both ends.
type Partition struct {
	// Lower is the smallest domain size that falls into the partition.
	Lower int `json:"lower"`
	// Upper is the largest domain size that falls into the partition.
	Upper int `json:"upper"`
}
// Lsh is the per-partition index abstraction used by LshEnsemble.
// It is implemented by LshForest and LshForestArray.
type Lsh interface {
	// Add inserts a key together with its MinHash signature. The key is
	// not searchable until Index is called after the insertion.
	Add(key interface{}, sig []uint64)

	// Index makes every key added so far searchable.
	Index()

	// Query searches the index for the MinHash signature sig using the
	// LSH parameters k and l, and writes the resulting keys to out.
	// Closing done cancels the query execution.
	Query(sig []uint64, k, l int, out chan<- interface{}, done <-chan struct{})

	// OptimalKL computes the optimal LSH parameters k and l given x, the
	// index domain size, q, the query domain size, and t, the containment
	// threshold. The resulting false positive (fp) and false negative (fn)
	// probabilities are returned as well.
	OptimalKL(x, q int, t float64) (optK, optL int, fp, fn float64)
}
// LshEnsemble is an LSH Ensemble index: it keeps one Lsh index per
// domain-size partition.
type LshEnsemble struct {
	// Partitions lists the domain-size ranges; Partitions[i] is served by
	// the i-th underlying Lsh index.
	Partitions []Partition

	// lshes holds one LSH index per partition, in the same order as Partitions.
	lshes []Lsh
	// maxK and numHash are the values passed to the constructor: the maximum
	// MinHash parameter K and the number of MinHash hash functions.
	maxK, numHash int
	// paramCache stores the (k, l) pairs computed by computeParams, keyed by
	// the string produced by cacheKey.
	paramCache cmap.ConcurrentMap
}
// NewLshEnsemble creates a new index made of MinHash LSH instances implemented
// with LshForest, one per entry of parts.
//
// numHash is the number of hash functions in MinHash.
// maxK is the maximum value for the MinHash parameter K - the number of hash
// functions per "band". Each forest is built with numHash/maxK as its second
// argument, so maxK must be non-zero whenever parts is non-empty.
// initSize is the initial size of the underlying hash tables to allocate.
func NewLshEnsemble(parts []Partition, numHash, maxK, initSize int) *LshEnsemble {
	ensemble := &LshEnsemble{
		Partitions: parts,
		lshes:      make([]Lsh, 0, len(parts)),
		maxK:       maxK,
		numHash:    numHash,
		paramCache: cmap.New(),
	}
	// One forest per partition, stored in partition order.
	for range parts {
		forest := NewLshForest(maxK, numHash/maxK, initSize)
		ensemble.lshes = append(ensemble.lshes, forest)
	}
	return ensemble
}
// NewLshEnsemblePlus creates a new index made of MinHash LSH instances
// implemented with LshForestArray, one per entry of parts.
//
// numHash is the number of hash functions in MinHash.
// maxK is the maximum value for the MinHash parameter K - the number of hash
// functions per "band".
// initSize is the initial size of the underlying hash tables to allocate.
func NewLshEnsemblePlus(parts []Partition, numHash, maxK, initSize int) *LshEnsemble {
	ensemble := &LshEnsemble{
		Partitions: parts,
		lshes:      make([]Lsh, 0, len(parts)),
		maxK:       maxK,
		numHash:    numHash,
		paramCache: cmap.New(),
	}
	// One forest array per partition, stored in partition order.
	for range parts {
		forestArray := NewLshForestArray(maxK, numHash, initSize)
		ensemble.lshes = append(ensemble.lshes, forestArray)
	}
	return ensemble
}
// Add inserts a new domain into the partition identified by partInd, the
// index of the partition within Partitions.
// The added domain won't be searchable until Index() is called.
func (e *LshEnsemble) Add(key interface{}, sig []uint64, partInd int) {
	target := e.lshes[partInd]
	target.Add(key, sig)
}
// Prepare adds a new domain to the index given its size; the partition is
// selected automatically as the first one whose inclusive [Lower, Upper]
// range contains size. Calling Add directly with a known partition index
// avoids the linear scan over Partitions.
// The added domain won't be searchable until Index() is called.
//
// It returns nil when the domain was added, and an error when no partition
// matches size (in which case nothing is added).
func (e *LshEnsemble) Prepare(key interface{}, sig []uint64, size int) error {
	for i := range e.Partitions {
		if size >= e.Partitions[i].Lower && size <= e.Partitions[i].Upper {
			e.Add(key, sig, i)
			// Report success here; falling through to the final return
			// would signal an error even though the domain was added.
			return nil
		}
	}
	return errors.New("No matching partition found")
}
// Index makes all added domains searchable by calling Index on the LSH index
// of every partition.
func (e *LshEnsemble) Index() {
	for _, lsh := range e.lshes {
		lsh.Index()
	}
}
// Query returns the candidate domain keys in a channel.
// It takes the MinHash signature of the query domain, sig, the query domain
// size, the containment threshold, and a cancellation channel.
// Closing channel done will cancel the query execution.
// The query signature must be generated using the same seed as the signatures
// of the indexed domains, and have the same number of hash functions.
func (e *LshEnsemble) Query(sig []uint64, size int, threshold float64, done <-chan struct{}) <-chan interface{} {
	// The per-partition (k, l) parameters are computed first, then handed
	// to the fan-out query.
	return e.queryWithParam(sig, e.computeParams(size, threshold), done)
}
// QueryTimed is similar to Query, but collects the candidate domain keys into
// a slice and also returns the running time. The timer starts after the
// per-partition parameters have been computed and stops once the result
// channel has been drained. The returned slice is never nil.
func (e *LshEnsemble) QueryTimed(sig []uint64, size int, threshold float64) (result []interface{}, dur time.Duration) {
	// Compute the optimal k and l for each partition.
	params := e.computeParams(size, threshold)

	// done is only closed on return; the query runs to completion.
	done := make(chan struct{})
	defer close(done)

	result = []interface{}{}
	begin := time.Now()
	candidates := e.queryWithParam(sig, params, done)
	for key := range candidates {
		result = append(result, key)
	}
	return result, time.Since(begin)
}
// queryWithParam queries every partition's LSH index concurrently, using
// params[i] as the (k, l) pair for the i-th partition. All partitions write
// their candidate keys to the single unbuffered channel that is returned;
// the channel is closed once every partition query has returned.
// done is passed through to each partition query.
func (e *LshEnsemble) queryWithParam(sig []uint64, params []param, done <-chan struct{}) <-chan interface{} {
	out := make(chan interface{})
	var pending sync.WaitGroup

	// Start one goroutine per partition.
	for idx, index := range e.lshes {
		pending.Add(1)
		go func(lsh Lsh, p param) {
			defer pending.Done()
			lsh.Query(sig, p.k, p.l, out, done)
		}(index, params[idx])
	}

	// Close the output only after all partition queries have finished.
	go func() {
		pending.Wait()
		close(out)
	}()
	return out
}
// computeParams returns the optimal k and l for each partition, given the
// query domain size and the containment threshold. The upper bound of each
// partition is used as the index domain size. Results are looked up in, and
// stored into, paramCache under the key produced by cacheKey.
func (e *LshEnsemble) computeParams(size int, threshold float64) []param {
	result := make([]param, len(e.Partitions))
	for i := range e.Partitions {
		upper := e.Partitions[i].Upper
		key := cacheKey(upper, size, threshold)

		// Reuse a previously computed pair when one is cached.
		if hit, ok := e.paramCache.Get(key); ok {
			result[i] = hit.(param)
			continue
		}

		// Cache miss: compute, remember, and use the fresh pair.
		k, l, _, _ := e.lshes[i].OptimalKL(upper, size, threshold)
		fresh := param{k: k, l: l}
		e.paramCache.Set(key, fresh)
		result[i] = fresh
	}
	return result
}
// cacheKey builds the parameter-cache key for an index domain size x, a query
// domain size q and a threshold t. The two sizes are written as hexadecimal
// zero-padded to at least 8 digits, and the threshold is rounded to 2 decimal
// places, so thresholds that agree after rounding share one key.
func cacheKey(x, q int, t float64) string {
	const layout = "%.8x %.8x %.2f"
	return fmt.Sprintf(layout, x, q, t)
}