-
Notifications
You must be signed in to change notification settings - Fork 80
/
predict.lua
67 lines (63 loc) · 2.07 KB
/
predict.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
require 'cutorch'
require 'cunn'
require './SETTINGS'
require './lib/data_augmentation'
require './lib/preprocessing'
require './very_deep_model'
--- Run test-time-augmented inference and write a CSV submission.
--
-- Each test image is expanded by `data_augmentation` into several views.
-- The model scores every view, and the per-view scores are summed per image.
-- The label written for an image is the argmax of that sum. Every image
-- contributes the same number of views, so this equals the argmax of the
-- mean score.
--
-- @param file    output path. It is written only after input validation
--                passes. It receives a header line "id,label" followed by
--                one "<1-based row index>,<label>" line per test image.
-- @param model   network that accepts CUDA input (batches are CudaTensors).
--                It should already be in evaluation mode.
-- @param params  preprocessing parameters forwarded to `preprocessing`.
-- @param test_x  tensor of test images indexed along dim 1. Its size along
--                dim 1 must be a multiple of BATCH_SIZE (100).
-- @raise error if the test set size is not a multiple of BATCH_SIZE, or if
--        `file` cannot be opened for writing.
local function predict(file, model, params, test_x)
   local BATCH_SIZE = 100 -- test images expanded per outer iteration
   local STEP = 64        -- augmented views per forward pass
   local n_test = test_x:size(1)

   -- Validate before touching the filesystem, so bad input neither
   -- truncates `file` nor leaves an open handle behind.
   if n_test % BATCH_SIZE ~= 0 then
      error("expect test size % " .. BATCH_SIZE .. " == 0")
   end
   local fp = assert(io.open(file, "w"))
   fp:write("id,label\n")

   -- DA_SIZE is the 4-D size of data_augmentation(image): (views, d2, d3, d4).
   -- Presumably (views, channels, height, width). TODO confirm.
   local DA_SIZE = nil
   local batch = nil -- reusable fixed-size GPU input buffer
   local preds = nil -- (n_test, n_outputs) summed scores, sized from the model output
   for i = 1, n_test, BATCH_SIZE do
      if DA_SIZE == nil then
         DA_SIZE = data_augmentation(test_x[1]):size()
         batch = torch.CudaTensor(STEP, DA_SIZE[2], DA_SIZE[3], DA_SIZE[4])
      end
      local n_views = DA_SIZE[1]

      local x = torch.Tensor(BATCH_SIZE, n_views, DA_SIZE[2], DA_SIZE[3], DA_SIZE[4])
      local index = torch.LongTensor(BATCH_SIZE, n_views)
      for j = 1, BATCH_SIZE do
         x[j]:copy(data_augmentation(test_x[i + j - 1]))
         index[j]:fill(i + j - 1)
      end
      -- Flatten (image, view) into rows. `index` maps every row back to the
      -- test image it came from.
      x = x:view(BATCH_SIZE * n_views, DA_SIZE[2], DA_SIZE[3], DA_SIZE[4])
      index = index:view(BATCH_SIZE * n_views)
      -- The return value is ignored, so this presumably modifies x in place.
      preprocessing(x, params)
      x = x:cuda()

      local n_rows = x:size(1)
      for j = 1, n_rows, STEP do
         local n = math.min(STEP, n_rows - j + 1)
         -- The forward pass always uses STEP rows. The final chunk is
         -- zero-padded, and outputs for the padded rows are ignored below.
         batch:zero()
         batch:narrow(1, 1, n):copy(x:narrow(1, j, n))
         local z = model:forward(batch):float()
         if preds == nil then
            preds = torch.Tensor(n_test, z:size(2)):zero()
         end
         -- Sum (not average) scores over views; the argmax is the same.
         for l = 1, n do
            preds[index[j + l - 1]]:add(z[l])
         end
      end

      for j = 1, BATCH_SIZE do
         local row = i + j - 1
         local _, max_i = preds[row]:max(1)
         -- ID2LABEL is a global, presumably defined in SETTINGS.
         fp:write(string.format("%d,%s\n", row, ID2LABEL[max_i[1]]))
      end
      xlua.progress(i, n_test)
      collectgarbage()
   end
   xlua.progress(n_test, n_test)
   fp:close()
end
--- Script entry point: load the test set, trained network and preprocessing
-- parameters, then write predictions to ./submission.txt via `predict`.
--
-- Inputs read:
--   <DATA_DIR>/test_x.bin              test images (DATA_DIR is a global,
--                                      presumably set by SETTINGS)
--   models/very_deep_20.model          serialized network, moved to the GPU
--   models/preprocessing_params.bin    parameters passed to `preprocessing`
local function prediction()
   local test_path = string.format("%s/test_x.bin", DATA_DIR)
   local test_images = torch.load(test_path)

   local net = torch.load("models/very_deep_20.model")
   net = net:cuda()
   net:evaluate() -- switch layers such as dropout/batch-norm to inference behavior

   local prep_params = torch.load("models/preprocessing_params.bin")
   predict("./submission.txt", net, prep_params, test_images)
end
prediction()