-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.go
136 lines (122 loc) · 3.53 KB
/
predict.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package trietree
import (
"unicode/utf8"
)
// Prediction locates one key found inside a query string.
//
// Start and End are byte offsets into the query, so query[Start:End] is the
// matched text; ID is the edge ID of the tree node that produced the match.
type Prediction struct {
	Start int // byte offset where the key begins in the query
	End   int // byte offset just past the key's last byte (exclusive)
	ID    int // edge ID of the matching node
}

// PredictionIter hands out Predictions one per call and returns nil once
// there are none left.
type PredictionIter func() *Prediction
// PredictIter returns a PredictionIter that feeds query through dt one rune
// at a time. After each rune it reports, as a Prediction, every node with a
// positive edge ID found on the current node's failure chain (the current
// node included). Once query is exhausted the iterator returns nil.
func (dt *DTree) PredictIter(query string) PredictionIter {
	var tree predictableTree[*DNode] = dt
	return predictIter(tree, query)
}
// PredictIter returns a PredictionIter that feeds query through st one rune
// at a time. After each rune it reports, as a Prediction, every node with a
// positive edge ID found on the current node's failure chain (the current
// node included). Once query is exhausted the iterator returns nil.
func (st *STree) PredictIter(query string) PredictionIter {
	var tree predictableTree[int] = st
	return predictIter(tree, query)
}
// predictableTree is the minimal view of a trie that predictIter needs, so a
// single matching loop serves both DTree (nodes addressed by *DNode) and
// STree (nodes addressed by int index).
type predictableTree[T comparable] interface {
	// root returns the node traversal starts from.
	root() T
	// nextNode returns the node to move to from the given node after
	// reading the given rune.
	nextNode(T, rune) T
	// nodeFail returns the node's failure link.
	nodeFail(T) T

	// nodeId returns the node's edge ID; predictIter reports only nodes
	// whose ID is positive.
	nodeId(T) int
	// nodeLevel returns the key length that predictIter uses as a rune
	// count when locating where a match starts.
	nodeLevel(T) int
}
// Accessors through which *DTree satisfies predictableTree[*DNode]: nodes are
// identified by pointer, and each answer is read directly off the node.

// root returns the address of the root node embedded in dt.
func (dt *DTree) root() *DNode {
	return &dt.Root
}

// nodeId returns the EdgeID stored on n.
func (dt *DTree) nodeId(n *DNode) int {
	return n.EdgeID
}

// nodeLevel returns the Level stored on n.
func (dt *DTree) nodeLevel(n *DNode) int {
	return n.Level
}

// nodeFail returns n's Failure link.
func (dt *DTree) nodeFail(n *DNode) *DNode {
	return n.Failure
}
// Accessors through which *STree satisfies predictableTree[int]: a node is
// its index into st.Nodes, and index 0 is the root.

// root returns index 0.
func (st *STree) root() int {
	return 0
}

// nodeId returns the EdgeID of node n.
func (st *STree) nodeId(n int) int {
	return st.Nodes[n].EdgeID
}

// nodeLevel looks the level up in st.Levels, which is indexed by EdgeID-1.
// It is therefore only valid for nodes with a positive edge ID; an ID of 0
// would index position -1 and panic.
func (st *STree) nodeLevel(n int) int {
	id := st.nodeId(n)
	return st.Levels[id-1]
}

// nodeFail returns the index stored as node n's failure link.
func (st *STree) nodeFail(n int) int {
	return st.Nodes[n].Fail
}
// traverser feeds a query string into a predictableTree one rune at a time.
type traverser[T comparable] struct {
	tree  predictableTree[T] // tree being walked
	query string             // part of the query not consumed yet
	pivot T                  // node reached after the runes consumed so far
	index int                // bytes of the original query consumed so far
}

// newTraverser returns a traverser positioned at tree's root with the whole
// of query still to consume.
func newTraverser[T comparable](tree predictableTree[T], query string) traverser[T] {
	tr := traverser[T]{tree: tree, query: query}
	tr.pivot = tree.root()
	return tr
}

// next consumes one rune from the remaining query and moves the pivot to
// tree.nextNode(pivot, rune). It returns the new pivot together with the
// byte offset, in the original query, just past the consumed rune. Invalid
// UTF-8 is consumed one byte at a time as utf8.RuneError. When nothing is
// left it returns the zero T, 0 and valid == false without changing state.
func (tr *traverser[T]) next() (node T, end int, valid bool) {
	if len(tr.query) == 0 {
		return node, 0, false
	}
	// For non-empty input DecodeRuneInString always reports size >= 1, so
	// every successful call consumes at least one byte.
	r, size := utf8.DecodeRuneInString(tr.query)
	tr.query, tr.index = tr.query[size:], tr.index+size
	tr.pivot = tr.tree.nextNode(tr.pivot, r)
	return tr.pivot, tr.index, true
}

// close drops whatever remains of the query, so every later call to next
// reports valid == false. The pivot and index are left untouched.
func (tr *traverser[T]) close() {
	tr.query = ""
}
// trailingIndex returns the byte offset in s at which its last n runes
// begin, so s[trailingIndex(s, n):] holds exactly those runes. If s has
// fewer than n runes the result is 0; if n <= 0 it is len(s). Invalid UTF-8
// bytes each count as one rune.
func trailingIndex(s string, n int) int {
	i := len(s)
	// DecodeLastRuneInString reports size >= 1 for any non-empty input, so
	// each step moves i back by at least one byte.
	for ; n > 0 && i > 0; n-- {
		_, size := utf8.DecodeLastRuneInString(s[:i])
		i -= size
	}
	return i
}
// predictIter builds the iterator behind DTree.PredictIter and
// STree.PredictIter.
//
// Each call of the returned function resumes where the previous one stopped.
// While the current node's failure chain has not been fully inspected, it
// keeps following failure links, stopping once the root has been inspected.
// It returns the first node with a positive edge ID as a Prediction. End is
// the byte offset just past the most recently consumed rune, and Start lies
// nodeLevel runes before End. When the chain is used up, it consumes the next
// rune of query and starts again from the node reached. Once query is
// exhausted it returns nil, and keeps returning nil on later calls.
func predictIter[T comparable](tree predictableTree[T], query string) func() *Prediction {
	tr := newTraverser(tree, query)
	var (
		cur     T    // next node on the failure chain to inspect
		end     int  // byte offset in query just past the last consumed rune
		walking bool // whether cur still has to be inspected
	)
	return func() *Prediction {
		for {
			if !walking {
				var ok bool
				if cur, end, ok = tr.next(); !ok {
					tr.close()
					return nil
				}
				walking = true
			}
			for walking {
				n := cur
				// The root ends the chain: after inspecting it, the next
				// call advances the traverser instead of following its link.
				walking = n != tree.root()
				cur = tree.nodeFail(n)
				if id := tree.nodeId(n); id > 0 {
					return &Prediction{
						Start: trailingIndex(query[:end], tree.nodeLevel(n)),
						End:   end,
						ID:    id,
					}
				}
			}
		}
	}
}