----------------------------------------------------------------------
-- Probe.lua (from the nnx package, fork of clementfarabet/lua---nnx)
-- nn.Probe: a pass-through module that prints/displays the inputs and
-- gradients flowing through a network, for debugging purposes.
----------------------------------------------------------------------
-- external dependencies: nn for the Module base class, xlua for named
-- argument unpacking, and image for the optional display windows
require 'nn'
require 'xlua'
require 'image'

local Probe, parent = torch.class('nn.Probe', 'nn.Module')

function Probe:__init(...)
   parent.__init(self)
   -- note: the flag is named 'content' so that it matches the field
   -- tested in forward()/backward() and serialized in write()/read()
   -- (the former 'print' arg was set but never read back)
   xlua.unpack_class(self, {...}, 'nn.Probe',
      'print/display input/gradients of a network',
      {arg='name', type='string', help='unique name to identify probe', req=true},
      {arg='content', type='boolean', help='print full tensor content', default=false},
      {arg='display', type='boolean', help='display tensor', default=false},
      {arg='size', type='boolean', help='print tensor size', default=false},
      {arg='backw', type='boolean', help='activates probe for backward()', default=false})
end

-- forward(): pass the input through unchanged, optionally printing its
-- size/content and/or showing it in a display window
function Probe:forward(input)
   self.output = input
   if self.size or self.content then
      print('')
      print('<probe::' .. self.name .. '> forward()')
      if self.content then print(input)
      elseif self.size then print(#input)
      end
   end
   if self.display then
      self.winf = image.display{image=input, win=self.winf}
   end
   return self.output
end

-- backward(): pass the gradient through unchanged; reporting is only
-- active when the probe was created with backw=true
function Probe:backward(input, gradOutput)
   self.gradInput = gradOutput
   if self.backw then
      if self.size or self.content then
         print('')
         print('<probe::' .. self.name .. '> backward()')
         if self.content then print(gradOutput)
         elseif self.size then print(#gradOutput)
         end
      end
      if self.display then
         self.winb = image.display{image=gradOutput, win=self.winb}
      end
   end
   return self.gradInput
end

-- serialization: store/restore the probe's name and flags alongside the
-- base Module state
function Probe:write(file)
   parent.write(self, file)
   file:writeObject(self.name)
   file:writeBool(self.content)
   file:writeBool(self.display)
   file:writeBool(self.size)
   file:writeBool(self.backw)
end

function Probe:read(file)
   parent.read(self, file)
   self.name = file:readObject()
   self.content = file:readBool()
   self.display = file:readBool()
   self.size = file:readBool()
   self.backw = file:readBool()
end