-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNnwEstimatorLlr.lua
87 lines (73 loc) · 3.26 KB
/
NnwEstimatorLlr.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
-- NnwEstimatorLlr.lua
-- estimate value using local linear regression of k nearest neighbors
-- API overview (inside `if false` so it documents usage but never runs)
if false then
   -- the only kernel accepted by the constructor is 'epanechnikov quadratic'
   local llr = NnwEstimatorLlr(xs, ys, 'epanechnikov quadratic')
   -- estimate using the k nearest neighbors; regularizer is a number >= 0
   local ok, estimate = llr:estimate(query, {k = k, regularizer = regularizer})
   -- ok == true  : estimate is a number
   -- ok == false : estimate is a string explaining why no estimate was made
end
--------------------------------------------------------------------------------
-- CONSTRUCTOR
--------------------------------------------------------------------------------
local _, parent = torch.class('NnwEstimatorLlr', 'NnwEstimator')

-- Build a local-linear-regression estimator over a set of observations.
-- ARGS:
-- xs         : 2D Tensor of observed inputs (presumably one row per
--              observation -- confirm against NnwEstimator)
-- ys         : 1D Tensor of observed targets
-- kernelName : string; must be exactly 'epanechnikov quadratic'
-- NOTE: kernelName is only validated, not stored; estimate() computes its
-- weights via Nnw.weights without consulting it.
function NnwEstimatorLlr:__init(xs, ys, kernelName)
   local vp, verbose = makeVerbose(false, 'NnwEstimatorLlr:__init')

   local argChecks = {
      {xs,         'xs',         'isTensor2D'},
      {ys,         'ys',         'isTensor1D'},
      {kernelName, 'kernelName', 'isString'},
   }
   verify(vp, verbose, argChecks)

   assert(kernelName == 'epanechnikov quadratic',
          'only kernel supported is epanechnikov quadratic')

   -- the NnwEstimator constructor handles the observations; estimate() later
   -- reads self._xs and self._ys, which are not assigned here
   parent.__init(self, xs, ys)
end -- __init()
--------------------------------------------------------------------------------
-- PUBLIC METHODS
--------------------------------------------------------------------------------
function NnwEstimatorLlr:estimate(query, params)
   -- estimate y for a new query point by local linear regression on its
   -- k nearest neighbors, ranked by Euclidean distance
   -- ARGS:
   -- query  : 1D Tensor; must have as many elements as self._xs has columns
   -- params : table
   --   params.k           : integer > 0 and <= number of observations,
   --                        number of neighbors
   --   params.regularizer : number >= 0, forwarded to Nnw.estimateLlr
   -- RESULTS:
   -- true, estimate : estimate is the estimate for the query
   --                  estimate is a number
   -- false, reason  : no estimate was produced
   --                  reason is a string explaining why
   -- ERRORS:
   -- raised (via verify / affirm / assert) when an argument is invalid
   local v, isVerbose = makeVerbose(false, 'NnwEstimatorLlr:estimate')
   if isVerbose then print('*******************************************') end
   verify(v, isVerbose,
          {{query, 'query', 'isTensor1D'},
           {params, 'params', 'isTable'}})
   affirm.isIntegerPositive(params.k, 'params.k')
   affirm.isNumberNonNegative(params.regularizer, 'params.regularizer')

   local k = params.k
   v('self', self)

   local nObs = self._ys:size(1)
   assert(k <= nObs,
          string.format('k (=%s) exceeds number of observations (=%d)',
                        tostring(k), nObs))

   -- Fail fast with a clear message on a dimension mismatch instead of
   -- letting it surface (or silently misbehave) inside Nnw.nearest.
   local nDims = self._xs:size(2)
   local queryDims = query:size(1)
   assert(queryDims == nDims,
          string.format('query has %d elements but xs has %d columns',
                        queryDims, nDims))

   local sortedDistances, sortedNeighborIndices = Nnw.nearest(self._xs,
                                                              query)
   v('sortedDistances', sortedDistances)
   v('sortedNeighborIndices', sortedNeighborIndices)

   -- the distance to the k-th nearest neighbor is the width handed to the
   -- kernel weighting
   -- NOTE(review): if k or more observations coincide with the query, lambda
   -- is 0; confirm how Nnw.weights / Nnw.estimateLlr handle that case
   local lambda = sortedDistances[k]
   local weights = Nnw.weights(sortedDistances, lambda)
   v('lambda', lambda)
   v('weights', weights)

   -- all-ones mask: presumably marks every observation as eligible to be used
   -- as a neighbor (confirm against Nnw.estimateLlr)
   local visible = torch.Tensor(nObs):fill(1)
   local ok, estimate = Nnw.estimateLlr(k,
                                        params.regularizer,
                                        sortedNeighborIndices,
                                        visible,
                                        weights,
                                        query,
                                        self._xs,
                                        self._ys)
   v('ok,estimate', ok, estimate)
   return ok, estimate
end -- estimate()
--------------------------------------------------------------------------------
-- PRIVATE METHODS (NONE)
--------------------------------------------------------------------------------