-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.jl
32 lines (31 loc) · 838 Bytes
/
train.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import Pkg; Pkg.activate("."); Pkg.instantiate();
include(ARGS[1])
include(ARGS[2])
# ---------------------------------------------------------------------------
# DATA: questions, vocabularies and mini-batches.
# Uses names brought in by the two files included above
# (`o`, `getQdata`, `getDicts`, `miniBatch`, `Knet`).
# ---------------------------------------------------------------------------
println("Loading questions ...")
# Fixed seed so randomness after this point is reproducible across runs
# (presumably including any shuffling inside `miniBatch` — TODO confirm).
Knet.seed!(11131994)
trnqstns = getQdata(o[:dhome], "train")
valqstns = getQdata(o[:dhome], "val")

println("Loading dictionaries ... ")
qvoc, avoc, i2w, i2a = getDicts(o[:dhome], "dic")

# sets[1] holds the training batches, sets[2] the validation batches.
# Built as an explicit Vector{Any}, batch size 64 for both splits.
sets = Any[miniBatch(trnqstns; B=64), miniBatch(valqstns; B=64)]

# The raw question lists are not used again once batched; drop the
# references so the GC can reclaim them (if nothing else still holds them).
trnqstns = valqstns = nothing
# ---------------------------------------------------------------------------
# MODEL: build (or restore) the model `M` and its companion copy `Mrun`.
# ---------------------------------------------------------------------------
# Uncomment to select a GPU device explicitly via Knet's `gpu(n)`.
#gpu(0)
@show arrtype
if o[:mfile] !== nothing && isfile(o[:mfile])
    # Resume from a checkpoint. `loadmodel` hands back both networks plus the
    # options they were saved with, so the restored `Mrun` is kept as-is
    # rather than being overwritten with `M`'s weights.
    M, Mrun, o = loadmodel(o[:mfile])
else
    M = MACNetwork(o)
    Mrun = MACNetwork(o)
    # Fresh start: make `Mrun` begin with exactly the same weights as `M`.
    # `copyto!` writes in place; the old `wr.value[:] = wi.value[:]` first
    # materialised a temporary copy of every parameter array.
    for (wr, wi) in zip(params(Mrun), params(M))
        copyto!(wr.value, wi.value)
    end
end
Knet.gc()
# ---------------------------------------------------------------------------
# FEATURES + TRAINING
# ---------------------------------------------------------------------------
# Load the "train" then the "val" features, passing `o[:h5]` through as the
# `h5` keyword (a tuple `map` evaluates its elements left to right).
trnfeats, valfeats = map(("train", "val")) do part
    loadFeatures(o[:dhome], part; h5=o[:h5])
end
# `train!` returns the model pair; rebind both globals to its result.
M, Mrun = train!(M, Mrun, sets, (trnfeats, valfeats), o)