Skip to content

Commit

Permalink
Merge pull request #111 from ridgerchu/master
Browse files Browse the repository at this point in the history
Integrate ATan Surrogate function.
  • Loading branch information
jeshraghian authored Jun 17, 2022
2 parents 0409f82 + 97e4350 commit 25bcd96
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions snntorch/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,29 @@ def inner(x):

return inner

class ATan(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, alpha=2.0):
ctx.save_for_backward(input_)
ctx.alpha = alpha
out = (input_ > 0).float()
return out

@staticmethod
def backward(ctx, grad_output):
(input_,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad = ctx.alpha / 2 / (1 + (math.pi / 2 * ctx.alpha * input_).pow_(2)) * grad_input
return grad, None

def atan(alpha = 2.0):
alpha = alpha

def inner(x):
return ATan.apply(x, alpha)

return inner


class Sigmoid(torch.autograd.Function):
"""
Expand Down

0 comments on commit 25bcd96

Please sign in to comment.