forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ConfusionMatrix.lua
112 lines (104 loc) · 3.5 KB
/
ConfusionMatrix.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
local ConfusionMatrix = torch.class('nn.ConfusionMatrix')
function ConfusionMatrix:__init(nclasses, classes)
if type(nclasses) == 'table' then
classes = nclasses
nclasses = #classes
end
self.mat = torch.FloatTensor(nclasses,nclasses):zero()
self.valids = torch.FloatTensor(nclasses):zero()
self.nclasses = nclasses
self.totalValid = 0
self.averageValid = 0
self.classes = classes or {}
end
function ConfusionMatrix:add(prediction, target)
if type(prediction) == 'number' then
-- comparing numbers
self.mat[target][prediction] = self.mat[target][prediction] + 1
elseif type(target) == 'number' then
-- prediction is a vector, then target assumed to be an index
local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction)
local _,prediction = lab.max(prediction_1d)
self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1
else
-- both prediction and target are vectors
local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction)
local target_1d = torch.FloatTensor(self.nclasses):copy(target)
local _,prediction = lab.max(prediction_1d)
local _,target = lab.max(target_1d)
self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1
end
end
function ConfusionMatrix:zero()
self.mat:zero()
self.valids:zero()
self.totalValid = 0
self.averageValid = 0
end
function ConfusionMatrix:updateValids()
local total = 0
for t = 1,self.nclasses do
self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum()
total = total + self.mat[t][t]
end
self.totalValid = total / self.mat:sum()
self.averageValid = 0
local nvalids = 0
for t = 1,self.nclasses do
if not sys.isNaN(self.valids[t]) then
self.averageValid = self.averageValid + self.valids[t]
nvalids = nvalids + 1
end
end
self.averageValid = self.averageValid / nvalids
end
function ConfusionMatrix:__tostring__()
self:updateValids()
local str = 'ConfusionMatrix:\n'
local nclasses = self.nclasses
str = str .. '['
for t = 1,nclasses do
local pclass = self.valids[t] * 100
pclass = string.format('%2.3f', pclass)
if t == 1 then
str = str .. '['
else
str = str .. ' ['
end
for p = 1,nclasses do
str = str .. '' .. string.format('%8d', self.mat[t][p])
end
if self.classes and self.classes[1] then
if t == nclasses then
str = str .. ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n'
else
str = str .. '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n'
end
else
if t == nclasses then
str = str .. ']] ' .. pclass .. '% \n'
else
str = str .. '] ' .. pclass .. '% \n'
end
end
end
str = str .. ' + average row correct: ' .. (self.averageValid*100) .. '% \n'
str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%'
return str
end
function ConfusionMatrix:write(file)
file:writeObject(self.mat)
file:writeObject(self.valids)
file:writeInt(self.nclasses)
file:writeInt(self.totalValid)
file:writeInt(self.averageValid)
file:writeObject(self.classes)
end
function ConfusionMatrix:read(file)
self.mat = file:readObject()
self.valids = file:readObject()
self.nclasses = file:readInt()
self.totalValid = file:readInt()
self.averageValid = file:readInt()
self.classes = file:readObject()
end