-- train.lua (forked from szagoruyko/cifar.torch)
require 'nn'
require 'xlua'
require 'optim'
require 'image'
require 'batchflip'

local Trainer = torch.class('Trainer')
-- opt is a table of:
--   backend            'cuda', 'cudnn', 'float' or 'cl'
--   model              name of a model file under models/ (without the .lua suffix)
--   cudnnfastest       enable cudnn 'fastest'/benchmark mode (cudnn backend only)
--   learningRate       SGD learning rate (also settable per batch via trainBatch)
--   learningRateDecay  SGD learning rate decay
--   weightDecay        L2 regularization coefficient
--   momentum           SGD momentum
function Trainer:__init(opt)
  self.opt = opt
  self.backend = opt.backend
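
  -- three-stage pipeline: random horizontal flips (train-time augmentation),
  -- a Copy layer that moves the float batch onto the backend tensor type,
  -- then the network itself, loaded from models/<opt.model>.lua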
  local model = nn.Sequential()
  model:add(nn.BatchFlip():float())
  model:add(self:cast(nn.Copy('torch.FloatTensor', torch.type(self:cast(torch.Tensor())))))
  model:add(self:cast(dofile('models/'..opt.model..'.lua')))
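  -- nothing below the Copy layer needs gradients w.r.t. its input (the raw
  -- images), so skip that computation entirely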
  model:get(2).updateGradInput = function(input) return end
  if opt.backend == 'cudnn' then
    print('using cudnn')
    require 'cudnn'
    if opt.cudnnfastest then
      print('Using cudnn \'fastest\' mode')
      cudnn.fastest = true
      cudnn.benchmark = true
    end
    cudnn.convert(model:get(3), cudnn)
  end
  self.model = model
  print(model)
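
  -- flatten all learnable parameters and their gradients into two contiguous
  -- tensors, the layout optim.sgd operates on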
  self.parameters, self.gradParameters = model:getParameters()
  self.criterion = self:cast(nn.CrossEntropyCriterion())
  self.optimState = {
    learningRate = opt.learningRate,
    weightDecay = opt.weightDecay,
    momentum = opt.momentum,
    learningRateDecay = opt.learningRateDecay,
  }
end
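
-- cast a tensor or module to the tensor type implied by opt.backend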
function Trainer:cast(t)
  local backend = self.backend
  if backend == 'cuda' or backend == 'cudnn' then
    require 'cunn'  -- the cudnn backend also runs on CUDA tensors
    return t:cuda()
  elseif backend == 'float' then
    return t:float()
  elseif backend == 'cl' then
    require 'clnn'
    return t:cl()
  else
    error('Unknown backend ' .. backend)
  end
end
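
-- run one SGD step on a single minibatch and return the batch loss;
-- learningRate is taken per call so a schedule can be driven from outside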
function Trainer:trainBatch(learningRate, inputs, targets)
  local loss = nil
  self.optimState.learningRate = learningRate
  self.model:training()
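  -- lazily allocate (and reuse) a backend-typed buffer for the targets so the
  -- labels are copied to the device without a fresh allocation every batch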
  self.cutargets = self.cutargets or self:cast(torch.Tensor(targets:size()))
  self.cutargets:resize(targets:size())
  self.cutargets:copy(targets)
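
  -- feval closure for optim.sgd: given parameters x, return the batch loss
  -- and the gradient of the loss w.r.t. the flattened parameters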
  local feval = function(x)
    if x ~= self.parameters then self.parameters:copy(x) end
    self.gradParameters:zero()
    local outputs = self.model:forward(inputs)
    loss = self.criterion:forward(outputs, self.cutargets)
    local df_do = self.criterion:backward(outputs, self.cutargets)
    self.model:backward(inputs, df_do)
    return loss, self.gradParameters
  end
  optim.sgd(feval, self.parameters, self.optimState)
  return loss
end
function Trainer:predict(inputs)
  -- evaluation mode: disables the random flips and dropout, and switches
  -- batch normalization to its running statistics
  self.model:evaluate()
  local outputs = self.model:forward(inputs)
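  -- argmax over the class dimension (dim 2) gives the predicted label per sample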
  local _, predictions = outputs:max(2)
  return predictions:byte()
end
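
-- persist only the network itself (stage 3 of the pipeline); clearState()
-- drops intermediate buffers so the checkpoint stays small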
function Trainer:save(filepath)
  torch.save(filepath, self.model:get(3):clearState())
  collectgarbage(); collectgarbage()
end
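
-- Usage sketch, illustrative only: it assumes a model definition at
-- models/vgg.lua and minibatch tensors (inputs, targets) produced by the
-- surrounding project's data pipeline; none of these names, nor the
-- hyperparameter values, come from this file.
--
--   dofile 'train.lua'
--   local trainer = Trainer({
--     backend = 'cuda', model = 'vgg',
--     learningRate = 1, learningRateDecay = 1e-7,
--     weightDecay = 0.0005, momentum = 0.9,
--   })
--   local loss  = trainer:trainBatch(0.1, inputs, targets) -- one SGD step
--   local preds = trainer:predict(inputs)                  -- ByteTensor of labels
--   trainer:save('checkpoints/model.t7')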