-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathextract.lua
55 lines (44 loc) · 1.27 KB
/
extract.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
require 'xlua'
require 'json'
-- Open the output file named by opt.outFile for writing ('w' truncates any
-- existing content). `file` is intentionally global: extractBatch writes to it.
-- assert() turns an io.open failure (nil, message) into an immediate error
-- carrying that message, instead of a later "attempt to index a nil value".
-- NOTE(review): nothing in this section closes the handle; confirm the caller
-- does so, or relies on process exit.
file = assert(io.open(opt.outFile, 'w'))
--- Run extraction over indices 1..nExtract in batches of opt.batchSize.
-- For every batch, queues a loading job on the `donkeys` pool; each job's
-- results (inputs, ids) are passed to `extractBatch` as the job callback.
-- The final batch is shorter when nExtract is not a multiple of
-- opt.batchSize. The previous loop bound was the fractional quotient, so
-- those trailing samples were never requested.
-- Reads globals: nExtract, opt, donkeys, extractLoader, extractBatch.
function extract()
   cutorch.synchronize()

   -- Round up so a trailing partial batch is still processed.
   local nBatches = math.ceil(nExtract / opt.batchSize) -- nExtract is set in data.lua
   for i = 1, nBatches do
      collectgarbage()
      xlua.progress(i, nBatches)
      local indexStart = (i - 1) * opt.batchSize + 1
      -- Clamp so the last batch never asks for indices beyond nExtract.
      local indexEnd = math.min(indexStart + opt.batchSize - 1, nExtract)
      donkeys:addjob(
         -- work to be done by donkey thread
         function()
            local inputs, ids = extractLoader:get(indexStart, indexEnd)
            return inputs, ids
         end,
         -- callback that is run in the main thread once the work is done
         extractBatch
      )
   end

   -- Wait for all queued jobs (and their callbacks) before returning.
   donkeys:synchronize()
   cutorch.synchronize()
end -- of extract()
-- File-local CudaTensor reused across calls: extractBatch resizes it to each
-- incoming batch and copies the batch into it before the forward pass.
local inputs = torch.CudaTensor()
--- Encode the entries along the first dimension of `c` as JSON.
-- @param c object supporting `c:size(1)` and integer indexing `c[i]`
--          (extractBatch passes one row of the model output)
-- @return the result of `json.encode` on a Lua array holding c[1]..c[n]
function jsonStringFromCudaTensor(c)
   -- Must be local: the original assigned to a global `t`, leaking the
   -- scratch table into (and clobbering any existing) global of that name.
   local t = {}
   for i = 1, c:size(1) do
      t[i] = c[i]
   end
   return json.encode(t)
end
--- Callback for one loaded batch: run the model and append one line per sample.
-- Each output line is `<id>\t<json>\n`, where <json> comes from
-- jsonStringFromCudaTensor applied to that sample's row of the model output.
-- @param inputsCPU batch to feed the model; copied into the shared `inputs` buffer
-- @param idsCPU    indexable collection; idsCPU[i] is the id written for row i
function extractBatch(inputsCPU, idsCPU)
   -- Stage the batch in the reusable buffer, then run the forward pass on it.
   local staged = inputs:resize(inputsCPU:size())
   staged:copy(inputsCPU)
   local features = model:forward(inputs)
   collectgarbage()

   local rowCount = features:size(1)
   for row = 1, rowCount do
      local encoded = jsonStringFromCudaTensor(features[row])
      local sampleId = idsCPU[row]
      file:write(sampleId .. "\t" .. encoded .. "\n")
      -- Flush after every row so the file on disk tracks progress.
      file:flush()
   end

   collectgarbage()
   cutorch.synchronize()
end