-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNnwSmootherLlr.lua
100 lines (82 loc) · 3.32 KB
/
NnwSmootherLlr.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
-- NnwSmootherLlr.lua
-- estimate the value at a query observation from its k nearest visible
-- neighbors, using kernel weights and Nnw.estimateLlr (the "Llr" suffix
-- presumably stands for local linear regression)
-- API overview
if false then
   smoother = NnwSmootherLlr(allXs, allYs, visible, nncache,
                             'epanechnikov quadratic')
   -- when no estimate can be formed, ok is false and estimate is a message
   ok, estimate = smoother:estimate(queryIndex,
                                    {k = k, regularizer = regularizer})
end -- API overview
--------------------------------------------------------------------------------
-- CONSTRUCTOR
--------------------------------------------------------------------------------
local _, parent = torch.class('NnwSmootherLlr', 'NnwSmoother')

-- Construct a nearest-neighbor smoother over the observations (allXs, allYs).
--
-- ARGS
-- allXs      : 2D Tensor, one row per observation
-- allYs      : 1D Tensor of target values
-- visible    : 1D Tensor; estimate() uses only neighbors whose entry is 1
-- nncache    : Nncache object with the sorted neighbor indices per observation
-- kernelName : must be the string 'epanechnikov quadratic' (the only kernel)
function NnwSmootherLlr:__init(allXs, allYs, visible, nncache, kernelName)
   local trace, tracing = makeVerbose(false, 'NnwSmootherLlr:__init')
   local argSpecs = {
      {allXs, 'allXs', 'isTensor2D'},
      {allYs, 'allYs', 'isTensor1D'},
      {visible, 'visible', 'isTensor1D'},
      {nncache, 'nncache', 'isTable'},
   }
   verify(trace, tracing, argSpecs)

   -- only one kernel is implemented; reject any other name up front
   assert(kernelName == 'epanechnikov quadratic',
          'only kernel supported is epanechnikov quadratic')
   assert(torch.typename(nncache) == 'Nncache')

   -- delegate storage to the NnwSmoother base class (presumably it sets
   -- self._allXs, self._allYs, self._visible, self._nncache — confirm there)
   parent.__init(self, allXs, allYs, visible, nncache)
   trace('self', self)
   trace('self._nncache', self._nncache)
end -- __init()
--------------------------------------------------------------------------------
-- PUBLIC METHODS
--------------------------------------------------------------------------------
-- Estimate the value at observation obsIndex from its k nearest visible
-- neighbors. The kernel bandwidth lambda is the distance to the k-th visible
-- neighbor; the weighted fit itself is delegated to Nnw.estimateLlr.
--
-- ARGS
-- obsIndex : positive integer, row of self._allXs used as the query point
-- params   : table with fields
--            k           : positive integer, number of visible neighbors;
--                          must not exceed Nncachebuilder:maxNeighbors()
--            regularizer : non-negative number, passed to Nnw.estimateLlr
--
-- RETURNS
-- ok       : false when no estimate can be formed here; otherwise the ok
--            value returned by Nnw.estimateLlr
-- estimate : an error message string when ok is false from this function,
--            otherwise the estimate returned by Nnw.estimateLlr
function NnwSmootherLlr:estimate(obsIndex, params)
   local v, isVerbose = makeVerbose(false, 'NnwSmootherLlr:estimate')
   verify(v, isVerbose,
          {{obsIndex, 'obsIndex', 'isIntegerPositive'},
           {params, 'params', 'isTable'}})
   v('self', self)
   affirm.isIntegerPositive(params.k, 'params.k')
   affirm.isNumberNonNegative(params.regularizer, 'params.regularizer')

   local k = params.k
   assert(k <= Nncachebuilder:maxNeighbors())

   -- determine distances and lambda
   -- NOTE: code is the same as in SmootherKwavg:estimate
   local nObs = self._visible:size(1)
   -- distances[i] holds the distance to the i-th cached neighbor when that
   -- neighbor is visible; all other slots keep the 1e100 sentinel
   local distances = torch.Tensor(nObs):fill(1e100)
   local query = self._allXs[obsIndex]
   v('query', query)

   local sortedNeighborIndices = self._nncache:getLine(obsIndex)
   assert(sortedNeighborIndices)
   v('sortedNeighborIndices', sortedNeighborIndices)

   -- Walk the neighbors nearest-first until k visible ones have been seen.
   -- lambda is local: it used to leak into a global and could carry a stale
   -- value from an earlier call when fewer than k visible neighbors existed.
   -- NOTE(review): assumes the cache line has nObs entries or reaches k
   -- visible neighbors before its end — TODO confirm against Nncache.
   local lambda
   local found = 0
   for i = 1, nObs do
      local neighborIndex = sortedNeighborIndices[i]
      if self._visible[neighborIndex] == 1 then
         local distance =
            Nnw.euclideanDistance(self._allXs[neighborIndex], query)
         distances[i] = distance
         v('i,obsIndex,distance', i, neighborIndex, distance)
         found = found + 1
         if found == k then
            lambda = distance
            break
         end
      end
   end
   v('lambda', lambda)
   v('distances', distances)

   if lambda == nil then
      return false, 'fewer than k visible neighbors'
   end
   if lambda == 0 then
      return false, 'lambda == 0'
   end

   local weights = Nnw.weights(distances, lambda)
   v('weights', weights)
   local ok, estimate = Nnw.estimateLlr(k,
                                        params.regularizer,
                                        sortedNeighborIndices,
                                        self._visible,
                                        weights,
                                        query:clone(),
                                        self._allXs,
                                        self._allYs)
   v('ok, estimate', ok, estimate)
   return ok, estimate
end -- estimate