-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKnn-example-1.lua
51 lines (45 loc) · 1.56 KB
/
Knn-example-1.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
-- Knn-example-1.lua
-- smooth random points with k-nearest-neighbors and report the k with the lowest RMSE
require 'Knn'
-- Synthetic data set: nObservations points with nDims features each.
-- Features (xs) and targets (ys) are drawn independently from Normal(0, 1),
-- so there is no real signal for the smoother to find.
-- NOTE(review): `torch` is not required explicitly in this file; presumably it
-- is loaded by `require 'Knn'` or the host environment — confirm.
local nObservations = 100
local nDims = 10
local xs = torch.randn(nObservations, nDims)
local ys = torch.randn(nObservations)
-- return RMSE for k nearest neighbors
--- Estimate every observation from its k nearest neighbors and score the fit.
-- Each observation is smoothed with `useQueryPoint = false`, so the query
-- point itself is excluded from its own estimate.
-- @param k number of nearest neighbors passed to Knn:smooth
-- @param xs observations, one per row; only `xs:size(1)` is read here
--   (assumed to be a 2-D torch Tensor — TODO confirm against Knn's contract)
-- @param ys target value for each observation, indexed 1..xs:size(1)
-- @return root mean squared error of the estimates
-- @raise error if xs has no rows, or if Knn:smooth cannot produce an estimate
local function computeRmse(k, xs, ys)
  local knn = Knn() -- create instance of Knn; NOTE: no parameters
  local useQueryPoint = false -- don't use the query point to estimate itself
  -- Take the count from the data passed in rather than from a file-level
  -- constant, so the function works for any xs/ys pair.
  local nQueries = xs:size(1)
  assert(nQueries > 0, 'computeRmse: no observations')
  local sumSquaredErrors = 0
  for queryIndex = 1, nQueries do
    local ok, value = knn:smooth(xs, ys, queryIndex, k, useQueryPoint)
    if not ok then
      -- no estimate could be provided; value says why
      -- (tostring so a non-string reason cannot break the message itself)
      error('no estimate; queryIndex=' .. queryIndex .. ' reason=' .. tostring(value))
    end
    -- the estimate was provided; it's in value
    local residual = ys[queryIndex] - value
    sumSquaredErrors = sumSquaredErrors + residual * residual
  end
  -- RMSE is the root of the MEAN squared error, not of the sum.
  return math.sqrt(sumSquaredErrors / nQueries)
end
-- Sweep k over 1..40, print the RMSE for each value, and report the k with the
-- smallest RMSE. Typically the RMSE first falls as k grows and then rises again.
local bestK
local bestRmse = math.huge
for candidateK = 1, 40 do
  local candidateRmse = computeRmse(candidateK, xs, ys)
  print('k, RMSE', candidateK, candidateRmse)
  if candidateRmse < bestRmse then
    bestRmse, bestK = candidateRmse, candidateK
  end
end
print('best k value', bestK)