forked from torch/threads
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-threads-shared.lua
111 lines (94 loc) · 2.38 KB
/
test-threads-shared.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
require 'torch'
local threads = require 'threads'
local status, tds = pcall(require, 'tds')
tds = status and tds or nil
-- Test configuration: number of worker threads, number of jobs, and the
-- message each worker copies into its own state during initialization.
local nthread = 4
local njob = 10
local msg = "hello from a satellite thread"

-- Use the shared serializer for jobs and their upvalues. The checks made
-- after the pool is terminated expect in-place edits done by the workers to
-- be visible here, i.e. the data below is shared rather than copied.
threads.Threads.serialization('threads.sharedserialize')

-- Build a hash container: a tds.hash when tds loaded, else a plain table.
local function new_hash()
   local h = tds and tds.hash()
   if not h then
      h = {}
   end
   return h
end

-- Per-job data, indexed by job number i:
--   x  : tensors created with torch.ones(D)
--   xh : same as x, stored in a tds.hash when available
--   xs : FloatStorages of size D filled with 1
--   z  : K strings per job, stored at keys (i-1)*K+1 .. i*K
local x, xs = {}, {}
local xh = new_hash()
local z = new_hash()
local D = 10
local K = 100
if tds then
   K = 100000 -- good luck in non-shared (30M)
end

for i = 1, njob do
   x[i] = torch.ones(D)
   xh[i] = torch.ones(D)
   xs[i] = torch.FloatStorage(D):fill(1)
   local base = (i - 1) * K
   for j = 1, K do
      z[base + j] = "blah" .. i .. j
   end
end

-- Two full collection cycles before starting the pool.
for _ = 1, 2 do
   collectgarbage()
end

print('GO')
-- Initializer run once in each new worker thread's Lua state, receiving that
-- thread's index. It tries to load tds there (failure is tolerated via
-- pcall) and copies the main thread's `msg` upvalue into the worker-global
-- `gmsg`, which the jobs print later.
local function init_worker(worker_idx)
   pcall(require, 'tds')
   print('starting a new thread/state number:', worker_idx)
   -- Deliberately global: jobs executing in this worker state read `gmsg`.
   gmsg = msg -- we copy here an upvalue of the main thread
end

local pool = threads.Threads(nthread, init_worker)
-- Number of first-batch jobs whose end callback has run. End callbacks are
-- executed back on the main thread, which is why a plain local counter is
-- sufficient here.
local jobdone = 0

-- First batch: each job checks that it sees the main thread's data for
-- slot i, then increments that slot in place. With shared serialization
-- these increments are expected to be visible in the main thread afterwards.
for i=1,njob do
   pool:addjob(
      function()
         assert(x[i]:sum() == D)
         assert(xh[i]:sum() == D)
         assert(torch.FloatTensor(xs[i]):sum() == D)
         for j=1,K do
            assert(z[(i-1)*K+j] == "blah" .. i .. j)
         end
         x[i]:add(1)
         xh[i]:add(1)
         torch.FloatTensor(xs[i]):add(1)
         -- gmsg was set by the pool's per-thread initializer; __threadid is
         -- a worker-side global identifying the thread running this job.
         print(string.format('%s -- thread ID is %x', gmsg, __threadid))
         collectgarbage()
         collectgarbage()
         return __threadid
      end,
      -- End callback: receives the job's return value (the worker's
      -- __threadid) and counts the completed job.
      function(id)
         print(string.format("task %d finished (ran on thread ID %x)", i, id))
         jobdone = jobdone + 1
      end
   )
end

-- Second batch: collection-only jobs with no end callback, presumably to
-- make sure garbage collection inside the workers does not free data still
-- owned by the main thread (checked after the pool is terminated).
for _=1,njob do
   pool:addjob(
      function()
         collectgarbage()
         collectgarbage()
      end
   )
end

-- Wait until every queued job and its end callback has been processed.
pool:synchronize()
print(string.format('%d jobs done', jobdone))
-- Every first-batch job registered an end callback, so all must have run.
assert(jobdone == njob,
   string.format('expected %d jobs done, got %d', njob, jobdone))
pool:terminate()
-- Assert that, for every job slot i, the tensor in tx, the tensor in th and
-- the FloatStorage in ts (viewed through a FloatTensor) all sum to
-- `expected`. The `label` identifies which set of collections failed.
local function check_sums(label, tx, th, ts, expected)
   for i=1,njob do
      assert(tx[i]:sum() == expected,
         string.format('%s: tensor %d has wrong sum', label, i))
      assert(th[i]:sum() == expected,
         string.format('%s: hashed tensor %d has wrong sum', label, i))
      assert(torch.FloatTensor(ts[i]):sum() == expected,
         string.format('%s: storage %d has wrong sum', label, i))
   end
end

-- Did we do the job in shared mode? Each worker added 1 in place, so the
-- main thread must now see 2*D for every slot.
check_sums('shared', x, xh, xs, 2*D)

-- Serialize the current (doubled) data, then zero the originals in place.
local str = torch.serialize(x)
local strh = torch.serialize(xh)
local strs = torch.serialize(xs)
for i=1,njob do
   x[i]:zero()
   xh[i]:zero()
   xs[i]:fill(0)
end

-- The deserialized copies must still hold 2*D: they must not alias the
-- originals that were just zeroed.
local y = torch.deserialize(str)
local yh = torch.deserialize(strh)
local ys = torch.deserialize(strs)
check_sums('deserialized', y, yh, ys, 2*D)

print('PASSED')