Skip to content

Commit

Permalink
SparseConnection support
Browse files Browse the repository at this point in the history
  • Loading branch information
n-shevko committed Jan 23, 2025
1 parent dd03ea3 commit 0308b27
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 106 deletions.
37 changes: 28 additions & 9 deletions bindsnet/learning/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def update(self) -> None:
(self.connection.wmin != -np.inf).any()
or (self.connection.wmax != np.inf).any()
) and not isinstance(self, NoOp):
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
if self.connection.w.is_sparse:
raise Exception("SparseConnection isn't supported for wmin\\wmax")
else:
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)


class NoOp(LearningRule):
Expand Down Expand Up @@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
if self.nu[0].any():
source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
update = self.reduction(torch.bmm(source_s, target_x), dim=0)
if self.connection.w.is_sparse:
update = update.to_sparse()
self.connection.w -= update
del source_s, target_x

# Post-synaptic update.
Expand All @@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
)
source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
update = self.reduction(torch.bmm(source_x, target_s), dim=0)
if self.connection.w.is_sparse:
update = update.to_sparse()
self.connection.w += update
del source_x, target_s

super().update()
Expand Down Expand Up @@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:

# Pre-synaptic update.
update = self.reduction(torch.bmm(source_s, target_x), dim=0)
if self.connection.w.is_sparse:
update = update.to_sparse()
self.connection.w += self.nu[0] * update

# Post-synaptic update.
update = self.reduction(torch.bmm(source_x, target_s), dim=0)
if self.connection.w.is_sparse:
update = update.to_sparse()
self.connection.w += self.nu[1] * update

super().update()
Expand Down Expand Up @@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
a_minus = torch.tensor(a_minus, device=self.connection.w.device)

# Compute weight update based on the eligibility value of the past timestep.
update = reward * self.eligibility
self.connection.w += self.nu[0] * self.reduction(update, dim=0)
update = self.reduction(reward * self.eligibility, dim=0)
if self.connection.w.is_sparse:
update = update.to_sparse()
self.connection.w += self.nu[0] * update

# Update P^+ and P^- values.
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
Expand Down Expand Up @@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
self.eligibility_trace += self.eligibility / self.tc_e_trace

update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
if self.connection.w.is_sparse:
update = update.to_sparse()
# Compute weight update.
self.connection.w += (
self.nu[0] * self.connection.dt * reward * self.eligibility_trace
)
self.connection.w += update

# Update P^+ and P^- values.
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
Expand Down Expand Up @@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
) * source_x[:, None]

# Compute weight update.
self.connection.w += self.nu[0] * reward * self.eligibility_trace
update = self.nu[0] * reward * self.eligibility_trace
if self.connection.w.is_sparse:
update = update.to_sparse()
self.connection.w += update

super().update()
105 changes: 8 additions & 97 deletions bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,13 @@ def update(self, **kwargs) -> None:

mask = kwargs.get("mask", None)
if mask is not None:
if self.w.is_sparse:
raise Exception("Mask isn't supported for SparseConnection")
self.w.masked_fill_(mask, 0)

if self.Dales_rule is not None:
if self.w.is_sparse:
raise Exception("Dales_rule isn't supported for SparseConnection")
# weight that are negative and should be positive are set to 0
self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0
# weight that are positive and should be negative are set to 0
Expand Down Expand Up @@ -1947,105 +1951,12 @@ def reset_state_variables(self) -> None:
super().reset_state_variables()


class SparseConnection(AbstractConnection):
class SparseConnection(Connection):
# language=rst
"""
Specifies sparse synapses between one or two populations of neurons.
"""

def __init__(
self,
source: Nodes,
target: Nodes,
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = None,
**kwargs,
) -> None:
# language=rst
"""
Instantiates a :code:`Connection` object with sparse weights.
:param source: A layer of nodes from which the connection originates.
:param target: A layer of nodes to which the connection connects.
:param nu: Learning rate for both pre- and post-synaptic events. It also
accepts a pair of tensors to individualize learning rates of each neuron.
In this case, their shape should be the same size as the connection weights.
:param reduction: Method for reducing parameter updates along the minibatch
dimension.
:param weight_decay: Constant multiple to decay weights by on each iteration.
Keyword arguments:
:param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format
:param float sparsity: Fraction of sparse connections to use.
:param LearningRule update_rule: Modifies connection parameters according to
some rule.
:param float wmin: Minimum allowed value on the connection weights.
:param float wmax: Maximum allowed value on the connection weights.
:param float norm: Total weight per target neuron normalization constant.
"""
super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

w = kwargs.get("w", None)
self.sparsity = kwargs.get("sparsity", None)

assert (
w is not None
and self.sparsity is None
or w is None
and self.sparsity is not None
), 'Only one of "weights" or "sparsity" must be specified'

if w is None and self.sparsity is not None:
i = torch.bernoulli(
1 - self.sparsity * torch.ones(*source.shape, *target.shape)
)
if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
v = torch.clamp(
torch.rand(*source.shape, *target.shape), self.wmin, self.wmax
)[i.bool()]
else:
v = (
self.wmin
+ torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin)
)[i.bool()]
w = torch.sparse.FloatTensor(i.nonzero().t(), v)
elif w is not None and self.sparsity is None:
assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)"
if self.wmin != -np.inf or self.wmax != np.inf:
w = torch.clamp(w, self.wmin, self.wmax)

self.w = Parameter(w, requires_grad=False)

def compute(self, s: torch.Tensor) -> torch.Tensor:
# language=rst
"""
Compute convolutional pre-activations given spikes using layer weights.
:param s: Incoming spikes.
:return: Incoming spikes multiplied by synaptic weights (with or without
decaying spike activation).
"""
return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1)
# return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)

def update(self, **kwargs) -> None:
# language=rst
"""
Compute connection's update rule.
"""

def normalize(self) -> None:
# language=rst
"""
Normalize weights along the first axis according to total weight per target
neuron.
"""

def reset_state_variables(self) -> None:
# language=rst
"""
Contains resetting logic for the connection.
"""
super().reset_state_variables()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.w = Parameter(self.w.to_sparse(), requires_grad=False)

0 comments on commit 0308b27

Please sign in to comment.