-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNncache.lua
131 lines (107 loc) · 4.21 KB
/
Nncache.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
-- Nncache.lua
-- nearest neighbors cache
-- API overview
-- (wrapped in "if false" so the examples are parsed by Lua but never run)
if false then
   -- construction
   nnc = Nncache()
   -- setter and getter
   nnc:setLine(obsIndex, tensor1D)
   tensor1D = nnc:getLine(obsIndex) -- may return nil
   -- apply a function to each key-value pair
   local function f(key, value)
   end
   nnc:apply(f)
   -- saving to a file and restoring from one
   nnc:save(filePath)
   -- load and loadUsingPrefix are class functions: call them on Nncache,
   -- not on an existing instance
   nnc = Nncache.load(filePath)
   -- the suffix appended to filePathPrefix is determined by the class
   -- Nncachebuilder
   nnc = Nncache.loadUsingPrefix(filePathPrefix)
end
--------------------------------------------------------------------------------
-- CONSTRUCTION
--------------------------------------------------------------------------------
torch.class('Nncache')

--- Create an empty cache.
-- self._table maps obsIndex -> stored line.
-- self._lastValuesSize is the length of the first line stored by setLine;
-- it stays nil until setLine is first called.
function Nncache:__init()
   self._lastValuesSize = nil
   self._table = {}
end
--------------------------------------------------------------------------------
-- PUBLIC CLASS METHODS
--------------------------------------------------------------------------------
--- Deserialize an Nncache from disk.
-- Reads filePath with torch.load, using the format returned by
-- Nncachebuilder.format(), and checks that the object read back has torch
-- typename 'Nncache'.
-- @param filePath string, path of the serialized cache (checked by verify)
-- @return the Nncache instance stored at filePath
-- Raises an error if filePath is not a string, if torch.load fails, or if the
-- stored object is not an Nncache.
function Nncache.load(filePath)
   -- NOTE(review): verbose tracing is switched on here, unlike every other
   -- method in this file; confirm whether that is intentional.
   local v, isVerbose = makeVerbose(true, 'Nncache.load')
   verify(v, isVerbose, {{filePath, 'filePath', 'isString'}})

   local fileFormat = Nncachebuilder.format()
   local loaded = torch.load(filePath, fileFormat)
   local typename = torch.typename(loaded)
   v('typename', typename)
   assert(typename == 'Nncache', 'bad typename = ' .. tostring(typename))

   -- Entries are deliberately not checked for a fixed row count of 256:
   -- the original allXs may have had fewer than 256 observations.
   return loaded
end -- load
--- Load an Nncache whose path is filePathPrefix plus the merged-file suffix.
-- The suffix comes from Nncachebuilder (via Nncache._filePath).
-- @param filePathPrefix string, path without the suffix
-- @return the loaded Nncache (see Nncache.load for error behavior)
function Nncache.loadUsingPrefix(filePathPrefix)
   local fullPath = Nncache._filePath(filePathPrefix)
   return Nncache.load(fullPath)
end -- loadUsingPrefix
--------------------------------------------------------------------------------
-- PRIVATE CLASS METHODS
--------------------------------------------------------------------------------
--- Build the full path of a merged cache file from its prefix.
-- @param filePathPrefix string
-- @return filePathPrefix with Nncachebuilder.mergedFileSuffix() appended
function Nncache._filePath(filePathPrefix)
   local suffix = Nncachebuilder.mergedFileSuffix()
   return filePathPrefix .. suffix
end -- _filePath
--------------------------------------------------------------------------------
-- PUBLIC INSTANCE METHODS
--------------------------------------------------------------------------------
--- Call f(key, value) once for every stored line.
-- Keys are observation indices and values are the stored lines. Iteration
-- uses pairs, so the visiting order is unspecified.
-- @param f function taking (key, value); its return values are ignored
function Nncache:apply(f)
   local entries = self._table
   for obsIndex, line in pairs(entries) do
      f(obsIndex, line)
   end
end -- apply
--- Return the line stored for obsIndex, or nil if that slot is empty.
-- @param obsIndex positive integer key (checked by verify)
-- @return the stored 1D tensor, or nil
function Nncache:getLine(obsIndex)
   local v, isVerbose = makeVerbose(false, 'Nncache:getline')
   verify(v, isVerbose, {{obsIndex, 'obsIndex', 'isIntegerPositive'}})
   local line = self._table[obsIndex]
   return line
end -- getLine
--- Store values as the cache line for obsIndex.
-- All lines in one cache must have the same length. The first call records
-- the length in self._lastValuesSize, and every later call must match it.
-- Each obsIndex slot may be filled only once.
-- The tensor is stored by reference, not copied, so later changes to it are
-- visible through getLine.
-- @param obsIndex positive integer key (checked by verify)
-- @param values 1D tensor (checked by verify)
-- Raises an error if values differs in length from earlier lines, or if the
-- obsIndex slot is already filled.
function Nncache:setLine(obsIndex, values)
   local v, isVerbose = makeVerbose(false, 'Nncache:setLine')
   verify(v, isVerbose,
          {{obsIndex, 'obsIndex', 'isIntegerPositive'},
           {values, 'values', 'isTensor1D'}})
   v('self', self)

   -- the first call fixes the line length; later calls must agree with it
   local newSize = values:size(1)
   if self._lastValuesSize == nil then
      self._lastValuesSize = newSize
   else
      assert(self._lastValuesSize == newSize,
             string.format('cannot change size of values; \n was %s; \n is %s',
                           tostring(self._lastValuesSize),
                           tostring(newSize)))
   end

   -- refuse to overwrite a filled slot; the %s makes the message report
   -- which obsIndex collided
   assert(self._table[obsIndex] == nil,
          string.format('attempt to set cache line already filled; \nobsIndex %s',
                        tostring(obsIndex)))
   self._table[obsIndex] = values
end -- setLine
--- Serialize this cache to filePath with torch.save.
-- Uses the format returned by Nncachebuilder.format().
-- The method is deliberately named 'save', not 'write'. If it were named
-- 'write', the torch.save call below would call this method again
-- recursively.
-- @param filePath string, destination path (checked by verify)
function Nncache:save(filePath)
   local v, isVerbose = makeVerbose(false, 'Nncache:write')
   v('self', self)
   verify(v, isVerbose, {{filePath, 'filePath', 'isString'}})

   v('filePath', filePath)
   v('Nncachebuilder.format()', Nncachebuilder.format())
   torch.save(filePath, self, Nncachebuilder.format())
end -- save