forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Threshold.lua
36 lines (31 loc) · 1.01 KB
/
Threshold.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
local Threshold, parent = torch.class('nn.Threshold','nn.Module')

--- Construct a threshold module.
-- Per the usage text below, entries whose input falls under the threshold
-- are replaced by `v`. NOTE(review): the actual comparison lives in the C
-- kernel `Threshold_forward` (not visible here) -- confirm `<` vs `<=`.
-- @tparam[opt=1e-6] number th threshold value
-- @tparam[opt=0] number v replacement value
-- @raise usage error (via xlua.usage) if th or v is given but is not a number
function Threshold:__init(th,v)
   parent.__init(self)
   -- Validate before storing anything, so the fields are only ever set
   -- from arguments that passed the type check.
   if (th and type(th) ~= 'number') or (v and type(v) ~= 'number') then
      error(xlua.usage('nn.Threshold',
         'a threshold module, if input < threshold, then output = value',
         nil,
         {type='number', help='threshold'},
         {type='number', help='value'}))
   end
   self.threshold = th or 1e-6
   self.val = v or 0
end
--- Forward pass.
-- Delegates to the type-specific backend table `input.nn`, whose
-- `Threshold_forward` kernel presumably fills `self.output`; that field is
-- then returned.
-- @param input tensor to threshold
-- @return self.output
function Threshold:forward(input)
   local backend = input.nn
   backend.Threshold_forward(self, input)
   return self.output
end
--- Backward pass.
-- Delegates to the backend kernel `Threshold_backward` found on `input.nn`,
-- which presumably fills `self.gradInput`; that field is then returned.
-- @param input tensor given to the forward pass
-- @param gradOutput gradient with respect to the module's output
-- @return self.gradInput
function Threshold:backward(input, gradOutput)
   local backend = input.nn
   backend.Threshold_backward(self, input, gradOutput)
   return self.gradInput
end
--- Serialize the module.
-- Writes the parent's state first, then `threshold` and `val` as doubles,
-- in that order (read() consumes them in the same order).
-- @param file torch File-like object to write to
function Threshold:write(file)
   parent.write(self, file)
   for _, field in ipairs({'threshold', 'val'}) do
      file:writeDouble(self[field])
   end
end
--- Deserialize the module.
-- Restores the parent's state first, then reads `threshold` and `val` as
-- doubles in the same order write() emitted them.
-- @param file torch File-like object to read from
function Threshold:read(file)
   parent.read(self, file)
   for _, field in ipairs({'threshold', 'val'}) do
      self[field] = file:readDouble()
   end
end