-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.lua
97 lines (92 loc) · 3.21 KB
/
model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
require 'nn'
require 'cunn'
require 'optim'
require 'cudnn'
-- nngraph is required by the main script, not here.
--- Build one reconstruction criterion for a single stack of the network.
-- Supported names: 'ABS' -> nn.AbsCriterion (there is no cudnn AbsCriterion),
-- 'MSE' -> nn.MSECriterion.
-- @tparam[opt] string name criterion name; defaults to opt.criterion
-- @return a freshly constructed nn criterion
-- @raise error when the name is not supported. Previously this only printed
--   'no such criterion' and returned nil. That nil was then silently added to
--   a ParallelCriterion and surfaced much later as an obscure failure.
local function crit(name)
    name = name or opt.criterion
    if name == 'ABS' then
        return nn.AbsCriterion()
    elseif name == 'MSE' then
        return nn.MSECriterion()
    end
    -- level 2: report the error at the caller that requested the criterion
    error(string.format("no such criterion: %s (expected 'ABS' or 'MSE')",
                        tostring(name)), 2)
end
-- Criteria ---------------------------------------------------------------------
-- criterionAE (global) is a ParallelCriterion: a weighted sum of its members.
-- It receives one reconstruction term from crit() per stack (opt.nStack).
-- With opt.cGAN set, criterionDisc (global) likewise receives one
-- nn.BCECriterion per stack for the discriminator outputs.
-- A per-pixel cudnn.SpatialCrossEntropyCriterion was tried here before and is
-- intentionally not used.
criterionAE = nn.ParallelCriterion()
if opt.cGAN then
    criterionDisc = nn.ParallelCriterion()
end
do
    -- Append this stack's terms: reconstruction first, then (cGAN only) BCE.
    local function addStackTerms()
        criterionAE:add(crit())
        if opt.cGAN then
            criterionDisc:add(nn.BCECriterion())
        end
    end
    for _ = 1, opt.nStack do
        addStackTerms()
    end
end
-- Generator network (netG, global) ---------------------------------------------
-- opt.retrainG is either the string 'none' or the path of a saved model.
--   'none': run models/<opt.netType>.lua, call its createModelG(opt.nGPU), and
--           convert the result for opt.backend ('cudnn', 'cunn' or 'nn').
--   a path: restore that model with loadDataParallel (no backend conversion).
if opt.retrainG == 'none' then
    local defFile = 'models/' .. opt.netType .. '.lua'
    paths.dofile(defFile)
    print('=> Creating model from file: ' .. defFile)
    netG = createModelG(opt.nGPU)

    -- Backend converters. Each one takes the net and returns the net to keep.
    local toBackend = {
        cudnn = function(net)
            require 'cudnn'
            cudnn.convert(net, cudnn)
            return net
        end,
        cunn = function(net)
            require 'cunn'
            return net:cuda()
        end,
        nn = function(net)
            return net
        end,
    }
    local convert = toBackend[opt.backend]
    if convert == nil then
        error'Unsupported backend'
    end
    netG = convert(netG)
else
    assert(paths.filep(opt.retrainG), 'File not found: ' .. opt.retrainG)
    print('Loading model from file: ' .. opt.retrainG)
    netG = loadDataParallel(opt.retrainG, opt.nGPU)
end

print('=> netG')
print(netG)
print('=> Criterion')
print(criterionAE)

-- The generator and its criterion are moved to CUDA for every backend,
-- including 'nn'.
print('==> Converting model and criterion to CUDA')
netG:cuda()
criterionAE:cuda()
-- Discriminator network (netD, global): built only when opt.cGAN is set.
-- opt.retrainD is either the string 'none' or the path of a saved model.
--   'none': build a fresh discriminator from models/cGAN_model.lua, wrap it in
--           an nn.MapTable, and convert it for opt.backend.
--   a path: restore that model with loadDataParallel (no backend conversion).
if opt.cGAN then
    if opt.retrainD ~= 'none' then -- restore a saved discriminator
        assert(paths.filep(opt.retrainD), 'File not found: ' .. opt.retrainD)
        print('Loading model from file: ' .. opt.retrainD)
        netD = loadDataParallel(opt.retrainD, opt.nGPU)
    else
        paths.dofile('models/cGAN_model.lua') -- provides defineD_n_layers
        -- Fixed: the message used to name a non-existent 'modelss/' directory.
        print('=> Creating model from file: models/cGAN_model.lua')
        -- First argument is opt.inSize[1] + #opt.jointsIx. Presumably these are
        -- input channels plus one channel per selected joint; confirm against
        -- defineD_n_layers in models/cGAN_model.lua.
        -- NOTE(review): netD_module is a global. It is left global in case
        -- other scripts read it.
        netD_module = defineD_n_layers(opt.inSize[1] + #opt.jointsIx, opt.inSize[1],
                                       opt.ndf, opt.D_nLayers)
        -- MapTable applies its single member module to every element of an
        -- input table.
        netD = nn.MapTable()
        netD:add(netD_module)
        -- Backend conversion applies only to a freshly created discriminator.
        if opt.backend == 'cudnn' then
            require 'cudnn'
            cudnn.convert(netD, cudnn)
        elseif opt.backend == 'cunn' then
            require 'cunn'
            netD = netD:cuda()
        elseif opt.backend ~= 'nn' then
            error'Unsupported backend'
        end
    end
    print('=> netD')
    print(netD)
    print('=> Criterion disc')
    print(criterionDisc)
    -- The discriminator and its criterion are moved to CUDA for every backend.
    print('==> Converting model and criterion to CUDA')
    netD:cuda()
    criterionDisc:cuda()
end
-- Enable cudnn's speed-oriented flags (cudnn.fastest, then cudnn.benchmark).
-- Then run a full garbage-collection cycle to reclaim temporaries from setup.
for _, flag in ipairs({ 'fastest', 'benchmark' }) do
    cudnn[flag] = true
end
collectgarbage()