-
Notifications
You must be signed in to change notification settings - Fork 92
/
catballotbox.go
119 lines (96 loc) · 2.52 KB
/
catballotbox.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package CloudForest
import (
"sync"
)
// CatBallot holds the categorical votes cast for a single case inside a
// CatBallotBox. Map goes from a numeric category code to the accumulated
// weight of the votes for that category. Vote and Tally hold Mutex while
// they access Map.
type CatBallot struct {
	Mutex sync.Mutex
	Map   map[int]float64
}

// NewCatBallot allocates a CatBallot whose Map is empty but non-nil, so it
// can receive votes immediately.
func NewCatBallot() (cb *CatBallot) {
	return &CatBallot{Map: make(map[int]float64)}
}
// CatBallotBox records votes by trees, one CatBallot per case in Box.
// Each ballot carries its own mutex, so Vote and Tally on different cases
// take different locks.
//
// The embedded *CatMap translates category strings to the integer keys used
// in every ballot's Map: Vote converts with CatToNum and Tally converts back
// through Back.
//
// NOTE(review): the shared CatMap is not guarded by any lock in this type;
// confirm CatToNum is safe to call concurrently before relying on concurrent
// Votes that introduce new categories.
type CatBallotBox struct {
	*CatMap
	Box []*CatBallot
}
// NewCatBallotBox returns a ballot box for size cases. It holds one fresh,
// empty CatBallot per case and an empty category map shared by all cases.
func NewCatBallotBox(size int) *CatBallotBox {
	ballots := make([]*CatBallot, size)
	for i := range ballots {
		ballots[i] = NewCatBallot()
	}
	return &CatBallotBox{
		&CatMap{make(map[string]int), make([]string, 0)},
		ballots,
	}
}
// Vote adds weight to the tally for case casei being predicted as category
// pred. The category string is first mapped to its numeric code through the
// embedded CatMap. The addition itself happens under the case's own mutex,
// which is released even if the update panics.
//
// NOTE(review): CatToNum runs before, and outside of, the per-case lock, and
// nothing here locks the shared CatMap. If CatToNum registers unseen
// categories, concurrent Votes with new categories may race. Verify.
func (bb *CatBallotBox) Vote(casei int, pred string, weight float64) {
	predn := bb.CatToNum(pred)

	ballot := bb.Box[casei]
	ballot.Mutex.Lock()
	defer ballot.Mutex.Unlock()

	// A missing key reads as 0, so += both creates and accumulates the entry.
	ballot.Map[predn] += weight
}
// Tally returns the predicted category for case i, treating the feature as
// categorical or boolean. The result is the weighted mode: the category with
// the largest total vote weight.
//
// Ties go to the category with the smallest numeric code, so the result does
// not depend on Go's randomized map iteration order. If no category has a
// positive total, Tally returns "NA".
func (bb *CatBallotBox) Tally(i int) (predicted string) {
	predictedn := 0
	votes := 0.0

	ballot := bb.Box[i]
	ballot.Mutex.Lock()
	for k, v := range ballot.Map {
		// Range order over a map is unspecified; breaking ties on the category
		// code makes repeated tallies of the same ballot agree.
		if v > votes || (v == votes && votes > 0 && k < predictedn) {
			predictedn = k
			votes = v
		}
	}
	ballot.Mutex.Unlock()

	if votes > 0 {
		return bb.Back[predictedn]
	}
	return "NA"
}
/*
TallyError returns the balanced classification error for categorical features:

	1 - (1/K) * sum over k of ( |{i in X_k : Y'(i) == Y(i)}| / |X_k| )

where
	Y(i)  is the true label of case i,
	Y'(i) is the label predicted by Tally(i),
	X_k   is the set of cases whose true label is category k, and
	K     is the number of categories that at least one labeled case has.

Cases whose true category is missing are ignored. Categories with no labeled
cases are left out of the average. If no case has a known label, the error is
undefined and NaN is returned.

feature must implement CatFeature; otherwise the type assertion panics.
*/
func (bb *CatBallotBox) TallyError(feature Feature) (e float64) {
	catfeature := feature.(CatFeature)
	ncats := catfeature.NCats()
	correct := make([]int, ncats)
	total := make([]int, ncats)

	for i := 0; i < feature.Length(); i++ {
		if feature.IsMissing(i) {
			continue
		}
		value := catfeature.Geti(i)
		total[value]++
		if catfeature.NumToCat(value) == bb.Tally(i) {
			correct[value]++
		}
	}

	// Average per-category accuracy only over categories that actually occur.
	// A category with no labeled cases would contribute 0/0 = NaN and poison
	// the whole sum.
	accuracy := 0.0
	present := 0
	for k, n := range total {
		if n == 0 {
			continue
		}
		accuracy += float64(correct[k]) / float64(n)
		present++
	}

	// With no labeled cases present is 0, and the runtime float division 0/0
	// yields NaN: the error is undefined rather than zero.
	return 1.0 - accuracy/float64(present)
}