
Commit 1dc91b0 ("for virajcz")
Hananel-Hazan committed Jul 22, 2024 (1 parent: cb8ce0c)
Showing 1 changed file with 117 additions and 0 deletions.

bindsnet/network/nodes.py: 117 additions, 0 deletions
@@ -1699,3 +1699,120 @@ def set_batch_size(self, batch_size) -> None:
        super().set_batch_size(batch_size=batch_size)
        self.v = self.rest * torch.ones(batch_size, *self.shape, device=self.v.device)
        self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)


class IFNodes_PredInf(Nodes):
    # language=rst
    """
    Layer of `integrate-and-fire (IF) neurons <http://neuronaldynamics.epfl.ch/online/Ch1.S3.html>`_
    with an optional fire-only-once mode (``Fire_only_Once``), in which each neuron spikes at most
    once per run until its state is reset.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        thresh: Union[float, torch.Tensor] = -52.0,
        reset: Union[float, torch.Tensor] = -65.0,
        refrac: Union[int, torch.Tensor] = 5,
        lbound: Optional[float] = None,
        Fire_only_Once: bool = False,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of IF neurons.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param thresh: Spike threshold voltage.
        :param reset: Post-spike reset voltage.
        :param refrac: Refractory (non-firing) period of the neuron.
        :param lbound: Lower bound of the voltage.
        :param Fire_only_Once: Whether to allow only one spike per run until reset.
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )

        self.register_buffer(
            "reset", torch.tensor(reset, dtype=torch.float)
        )  # Post-spike reset voltage.
        self.register_buffer(
            "thresh", torch.tensor(thresh, dtype=torch.float)
        )  # Spike threshold voltage.
        self.register_buffer(
            "refrac", torch.tensor(refrac)
        )  # Post-spike refractory period.
        self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
        self.register_buffer(
            "refrac_count", torch.FloatTensor()
        )  # Refractory period counters.

        self.Fire_only_Once = Fire_only_Once
        if self.Fire_only_Once:
            # Make the refractory period effectively infinite so that each
            # neuron can spike at most once per run.
            self.refrac = torch.tensor(torch.iinfo(torch.int32).max)

        self.lbound = lbound  # Lower bound of voltage.

    def forward(self, x: torch.Tensor) -> None:
        # language=rst
        """
        Runs a single simulation step.

        :param x: Inputs to the layer.
        """
        # Integrate input voltages.
        self.v += (self.refrac_count <= 0).float() * x

        # Decrement refractory counters.
        self.refrac_count -= self.dt

        # Check for spiking neurons.
        self.s = self.v >= self.thresh

        # Refractoriness and voltage reset.
        self.refrac_count.masked_fill_(self.s, self.refrac)
        self.v.masked_fill_(self.s, self.reset)

        # Voltage clipping to lower bound.
        if self.lbound is not None:
            self.v.masked_fill_(self.v < self.lbound, self.lbound)

        super().forward(x)

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets relevant state variables.
        """
        super().reset_state_variables()
        self.v.fill_(self.reset)  # Neuron voltages.
        self.refrac_count.zero_()  # Refractory period counters.

    def set_batch_size(self, batch_size) -> None:
        # language=rst
        """
        Sets mini-batch size. Called when layer is added to a network.

        :param batch_size: Mini-batch size.
        """
        super().set_batch_size(batch_size=batch_size)
        self.v = self.reset * torch.ones(batch_size, *self.shape, device=self.v.device)
        self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)
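
For context, below is a minimal usage sketch of the new node type wired into a bindsnet network. It is not part of the commit: the layer names ("X", "Y"), sizes, random weights, and Bernoulli spike input are illustrative assumptions, and it assumes a recent bindsnet version in which Network.run takes an inputs dictionary keyed by layer name and state is cleared with reset_state_variables().

import torch

from bindsnet.network import Network
from bindsnet.network.nodes import Input, IFNodes_PredInf
from bindsnet.network.topology import Connection

# Illustrative sketch, not part of the commit: names, sizes, weights, and the
# random spike input below are assumptions chosen only for demonstration.
network = Network(dt=1.0)

source = Input(n=100, traces=True)
target = IFNodes_PredInf(n=10, traces=True, Fire_only_Once=True)

network.add_layer(source, name="X")
network.add_layer(target, name="Y")
network.add_connection(
    Connection(source=source, target=target, w=0.05 * torch.rand(100, 10)),
    source="X",
    target="Y",
)

# 250 time steps of random Bernoulli spike trains as input; with
# Fire_only_Once=True, each "Y" neuron should spike at most once per run.
spikes = torch.bernoulli(0.1 * torch.ones(250, source.n)).byte()
network.run(inputs={"X": spikes}, time=250)

print(network.layers["Y"].s.sum().item())  # Spikes in the final simulation step.
network.reset_state_variables()  # Clear voltages and refractory counters before the next run.

With Fire_only_Once=True the constructor sets the refractory period to a very large value, so once a "Y" neuron fires its refractory counter cannot reach zero again within an ordinary run; calling reset_state_variables() zeroes the counters and re-arms the layer.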
