forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CGOptimization.lua
68 lines (64 loc) · 2.83 KB
/
CGOptimization.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
-- Declare the class 'nn.CGOptimization' with 'nn.BatchOptimization' named as
-- its parent. `CG` is the table the methods below are attached to; `parent`
-- is used by the constructor to invoke the parent's __init.
local CG,parent = torch.class('nn.CGOptimization', 'nn.BatchOptimization')
--- Constructor: parse the optimizer options and initialise the CG state.
--
-- Loads 'liblbfgs', runs the parent constructor with the same arguments,
-- unpacks the named options below onto `self`, translates the `linesearch`
-- and `momentum` names into the numeric codes expected by `cg.init`, and
-- finally calls `cg.init`.
--
-- Options (see the help strings for details):
--   maxEvaluation, maxIterations, maxLineSearch, sparsity, linesearch,
--   momentum, parallelize.
--
-- NOTE(review): `sparsity` and `parallelize` are parsed here but are not
-- forwarded to `cg.init` in this function.
-- NOTE(review): `self.parameters`, `self.gradParameters` and `self.verbose`
-- are not set in this function; presumably the parent constructor provides
-- them -- confirm against nn.BatchOptimization.
function CG:__init(...)
   require 'liblbfgs'
   parent.__init(self, ...)

   xlua.unpack_class(self, {...},
      'CGOptimization', nil,
      {arg='maxEvaluation', type='number',
       help='maximum nb of function evaluations per pass (0 = no max)', default=0},
      {arg='maxIterations', type='number',
       help='maximum nb of iterations per pass (0 = no max)', default=0},
      {arg='maxLineSearch', type='number',
       help='maximum nb of steps in line search', default=20},
      {arg='sparsity', type='number',
       help='sparsity coef (Orthantwise C)', default=0},
      {arg='linesearch', type='string',
       help=[[ type of linesearch used:
"morethuente", "m",
"armijo", "a",
"wolfe", "w",
"strong_wolfe", "s"
]],
       default='wolfe'},
      {arg='momentum', type='string',
       help=[[ type of momentum used:
"fletcher-reeves", "fr",
"polack-ribiere", "pr",
"hestens-steifel", "hs",
"gilbert-nocedal", "gn"
]],
       default='fletcher-reeves'},
      {arg='parallelize', type='number',
       help='parallelize onto N cores (experimental!)', default=1}
   )

   -- Option name (long or short form) -> numeric code handed to cg.init.
   local linesearch_codes = {
      m = 0, morethuente = 0,
      a = 1, armijo = 1,
      w = 2, wolfe = 2,
      s = 3, strong_wolfe = 3,
   }
   local momentum_codes = {
      ['fr'] = 0, ['fletcher-reeves'] = 0,
      ['pr'] = 1, ['polack-ribiere'] = 1,
      ['hs'] = 2, ['hestens-steifel'] = 2,
      ['gn'] = 3, ['gilbert-nocedal'] = 3,
   }

   -- Any name that is not listed silently falls back to the default code
   -- (wolfe = 2, fletcher-reeves = 0). Note that 0 is truthy in Lua, so the
   -- `or` only fires when the lookup yields nil.
   local linesearch_code = linesearch_codes[self.linesearch] or 2
   local momentum_code = momentum_codes[self.momentum] or 0

   -- init CG state
   cg.init(self.parameters, self.gradParameters,
           self.maxEvaluation, self.maxIterations, self.maxLineSearch,
           momentum_code, linesearch_code, self.verbose)
end
--- Run one CG optimisation pass and store its result in `self.output`.
--
-- Installs this object's `evaluate` function as `lbfgs.evaluate`, then calls
-- `cg.run()`.
-- NOTE(review): the callback is set on the `lbfgs` table although the `cg`
-- solver is the one being run; presumably both solvers read the same hook
-- -- confirm against liblbfgs.
function CG:optimize()
   -- register the evaluation callback before starting the solver
   local evaluate = self.evaluate
   lbfgs.evaluate = evaluate

   -- run the solver; keep its (first) return value as the output
   local result = cg.run()
   self.output = result
end