update feature_neuron #4

Merged
merged 12 commits
Dec 26, 2020
231 changes: 92 additions & 139 deletions examples/tutorial_1_spikegen.ipynb

Large diffs are not rendered by default.

122 changes: 74 additions & 48 deletions examples/tutorial_2_FCN.ipynb
Expand Up @@ -16,9 +16,9 @@
{
"cell_type": "markdown",
"source": [
"# # Gradient-based Learning in Spiking Neural Networks\n",
"## Gradient-based Learning in Spiking Neural Networks\n",
"In this tutorial, we'll create a simple 2-layer fully-connected network (FCN) to classify the MNIST dataset.\n",
"We will use the backpropagation through time (BPTT) algorithm to do so.\n",
"We will use the backpropagation through time (BPTT) algorithm to train the network.\n",
"\n",
"If running in Google Colab:\n",
"* Ensure you are connected to GPU by checking Runtime > Change runtime type > Hardware accelerator: GPU\n",
Expand Down Expand Up @@ -107,8 +107,8 @@
"# Temporal Dynamics\n",
"num_steps = 25\n",
"time_step = 1e-3\n",
"tau_mem = 6.5e-4\n",
"tau_syn = 5.5e-4\n",
"tau_mem = 3e-3\n",
"tau_syn = 2.2e-3\n",
"alpha = float(np.exp(-time_step/tau_syn))\n",
"beta = float(np.exp(-time_step/tau_mem))\n",
"\n",
Expand Down Expand Up @@ -188,16 +188,14 @@
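For reference, the decay factors computed in the cell above come from discretizing exponential leak with the simulation time step: solving $\tau \, dU/dt = -U$ over one step of length $\Delta t$ (`time_step`) gives a per-step decay of $e^{-\Delta t/\tau}$, so

$$\alpha = e^{-\Delta t/\tau_{syn}}, \qquad \beta = e^{-\Delta t/\tau_{mem}}.$$

With the updated constants above ($\Delta t = 1\,$ms, $\tau_{syn} = 2.2\,$ms, $\tau_{mem} = 3\,$ms), this gives roughly $\alpha \approx 0.63$ and $\beta \approx 0.72$.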
"cell_type": "markdown",
"source": [
"## 2. Define Network\n",
"To define our network, we will import two functions from the `snntorch.neuron` module, which contains a series of neuron models and related functions.\n",
"snnTorch treats neurons as activations with recurrent connections, so that it integrates smoothly with PyTorch's pre-existing layer functions.\n",
"* `snntorch.neuron.LIF` is a simple Leaky Integrate and Fire (LIF) neuron. Specifically, it uses Stein's model which assumes instantaneous rise times for synaptic current and membrane potential.\n",
"* `snntorch.neuron.FastSigmoidSurrogate` defines separate forward and backward functions. The forward function is a Heaviside step function for spike generation. The backward function is the derivative of a fast sigmoid function, to ensure continuous differentiability.\n",
"FSS is mostly derived from:\n",
"snnTorch treats neurons as activations with recurrent connections. This allows for smooth integration with PyTorch.\n",
"The `snntorch.neuron` module contains a few useful neuron models and surrogate gradient functions to approximate the derivative of spiking.\n",
"Our network will use one type of neuron model and one surrogate gradient:\n",
"1. `snntorch.neuron.stein` is a basic leaky integrate and fire (LIF) neuron. Specifically, it assumes instantaneous rise times for synaptic current and membrane potential.\n",
"2. `snntorch.neuron.FastSigmoidSurrogate` defines separate forward and backward functions. The forward function is a Heaviside step function for spike generation. The backward function is the derivative of a fast sigmoid function, to ensure continuous differentiability.\n",
"The `FastSigmoidSurrogate` function has been adapted from:\n",
"\n",
">Neftci, E. O., Mostafa, H., and Zenke, F. (2019) Surrogate Gradient Learning in Spiking Neural Networks. https://arxiv.org/abs/1901/09948\n",
"\n",
"`snn.neuron.slope` is a variable that defines the slope of the backward surrogate.\n",
"TO-DO: Include visualisation."
">Neftci, E. O., Mostafa, H., and Zenke, F. (2019) Surrogate Gradient Learning in Spiking Neural Networks. https://arxiv.org/abs/1901/09948"
],
"metadata": {
"collapsed": false,
Expand All @@ -211,10 +209,10 @@
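As a rough illustration of the forward/backward split described above, a fast-sigmoid surrogate can be written as a custom `torch.autograd.Function`: the forward pass is the Heaviside step, and the backward pass substitutes the derivative of a fast sigmoid. The sketch below only shows the general pattern; the class name, threshold handling, and the way the slope is stored here are illustrative, not snnTorch's actual implementation.

```python
import torch

class FastSigmoidSurrogateSketch(torch.autograd.Function):
    """Heaviside forward pass; derivative of a fast sigmoid on the backward pass."""
    slope = 50  # steepness of the surrogate (illustrative stand-in for snn.neuron.slope)

    @staticmethod
    def forward(ctx, mem_shifted):
        # mem_shifted = membrane potential minus the firing threshold
        ctx.save_for_backward(mem_shifted)
        return (mem_shifted > 0).float()  # spike = Heaviside step

    @staticmethod
    def backward(ctx, grad_output):
        (mem_shifted,) = ctx.saved_tensors
        # d/dx of the fast sigmoid x / (1 + k|x|) is 1 / (1 + k|x|)^2
        k = FastSigmoidSurrogateSketch.slope
        return grad_output / (1.0 + k * mem_shifted.abs()) ** 2

spike_grad_sketch = FastSigmoidSurrogateSketch.apply  # used the same way as spike_grad below
```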
"execution_count": null,
"outputs": [],
"source": [
"from snntorch.neuron import LIF\n",
"from snntorch.neuron import Stein\n",
"from snntorch.neuron import FastSimgoidSurrogate as FSS\n",
"\n",
"spike_fn = FSS.apply\n",
"spike_grad = FSS.apply\n",
"snn.neuron.slope = 50"
],
"metadata": {
Expand All @@ -227,11 +225,36 @@
{
"cell_type": "markdown",
"source": [
"Now we can define our SNN. Defining an instance of `LIF` requires three arguments: 1) the surrogate spiking function, 2) $I_{syn}$ decay rate, $\\alpha$, and 3) $V_{mem}$ decay rate, $\\beta$.\n",
"The surrogate is passed to `spike_grad` and overrides the default gradient of the Heaviside step function.\n",
"If we did not override the default gradient, (zero everywhere, except for $x=1$ where it is technically infinite but clipped to 1 here), then learning would not take place for as long as the neuron was not emitting post-synaptic spikes.\n",
"\n",
"The LIF neuron is simply treated as a recurrent activation. It requires initialization of the post-synaptic spikes `spk1` and `spk2`, the synaptic current `syn1` and `syn2`, and the membrane potential `mem1` and `mem2`.\n",
"`snn.neuron.slope` defines the slope of the backward surrogate.\n",
"\n",
"We will use the final layer spikes and membrane for determining loss and accuracy, so we will record and return their histories in `spk2_rec` and `mem2_rec`.\n",
"TO-DO: Include visualisation."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"Now we can define our spiking neural network (SNN).\n",
"Creating an instance of the `Stein` neuron requires two arguments and two optional arguments:\n",
"1. $I_{syn}$ decay rate, $\\alpha$,\n",
"2. $V_{mem}$ decay rate, $\\beta$,\n",
"3. the surrogate spiking function, `spike_grad`=`FSS` (*default*: the gradient of the Heaviside function), and\n",
"4. the threshold for spiking, (*default*: 1.0).\n",
"\n",
"snnTorch treats the LIF neuron as a recurrent activation. Therefore, it requires initialization of its internal states.\n",
"For each layer, we initialize the synaptic current `syn1` and `syn2`, the membrane potential `mem1` and `mem2`, and the post-synaptic spikes `spk1` and `spk2` to zero.\n",
"A class method `init_stein` will take care of this.\n",
"\n",
"For rate coding, the final layer of spikes and membrane potential are used to determine accuracy and loss, respectively.\n",
"So their historical values are recorded in `spk2_rec` and `mem2_rec`.\n",
"\n",
"Keep in mind, the dataset we are using is just static MNIST. I.e., it is *not* time-varying.\n",
"Therefore, we pass the same MNIST sample to the input at each time step.\n",
Expand All @@ -256,13 +279,13 @@
"\n",
" # initialize layers\n",
" self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
" self.lif1 = LIF(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
" self.lif1 = Stein(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
" self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
" self.lif2 = LIF(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
" self.lif2 = Stein(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
"\n",
" def forward(self, x):\n",
" spk1, syn1, mem1 = self.lif1.init_hidden(batch_size, num_hidden)\n",
" spk2, syn2, mem2 = self.lif2.init_hidden(batch_size, num_outputs)\n",
" spk1, syn1, mem1 = self.lif1.init_stein(batch_size, num_hidden)\n",
" spk2, syn2, mem2 = self.lif2.init_stein(batch_size, num_outputs)\n",
"\n",
" spk2_rec = []\n",
" mem2_rec = []\n",
Expand Down Expand Up @@ -291,7 +314,7 @@
"cell_type": "markdown",
"source": [
"## 3. Training\n",
"Time for training! Let's first define a couple of functions to print out test/train accuracy for each minibatch."
"Time for training! Let's first define a couple of functions to print out test/train accuracy."
],
"metadata": {
"collapsed": false,
Expand All @@ -307,10 +330,10 @@
"source": [
"def print_batch_accuracy(data, targets, train=False):\n",
" output, _ = net(data.view(batch_size, -1))\n",
" _, am = output.sum(dim=0).max(1)\n",
" acc = np.mean((targets == am).detach().cpu().numpy())\n",
" _, idx = output.sum(dim=0).max(1)\n",
" acc = np.mean((targets == idx).detach().cpu().numpy())\n",
"\n",
" if train is True:\n",
" if train:\n",
" print(f\"Train Set Accuracy: {acc}\")\n",
" else:\n",
" print(f\"Test Set Accuracy: {acc}\")\n",
Expand All @@ -334,10 +357,12 @@
"cell_type": "markdown",
"source": [
"### 3.1 Optimizer & Loss\n",
"We'll apply the softmax function to the membrane potentials of the output layer in calculating a negative log-likelihood loss.\n",
"The Adam optimizer is used for weight updates.\n",
"\n",
"Accuracy is measured by counting the spikes of the output neurons. The neuron that fires the most frequently will be our predicted class."
"* *Output Activation*: We'll apply the softmax function to the membrane potentials of the output layer, rather than the spikes.\n",
"* *Loss*: This will then be used to calculate the negative log-likelihood loss.\n",
"By encouraging the membrane of the correct neuron class to reach the threshold, we expect that neuron will fire more frequently.\n",
"The loss could be applied to the spike count as well, but the membrane is continuous whereas spike count is discrete.\n",
"* *Optimizer*: The Adam optimizer is used for weight updates.\n",
"* *Accuracy*: Accuracy is measured by counting the spikes of the output neurons. The neuron that fires the most frequently will be our predicted class."
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -481,7 +506,8 @@
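A minimal sketch of the training step described above might look as follows. It assumes `net`, `data`, `targets`, `batch_size`, and `num_steps` are already defined in earlier cells, the learning-rate and beta values are placeholders, and this is not necessarily the exact code used in the tutorial.

```python
import torch
import torch.nn as nn

log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))  # placeholder hyperparameters

# forward pass over all time steps; net returns spike and membrane histories
spk2_rec, mem2_rec = net(data.view(batch_size, -1))

# accumulate the membrane-potential loss at every time step (softmax -> negative log-likelihood)
loss = torch.zeros(1, device=targets.device)
for step in range(num_steps):
    loss += loss_fn(log_softmax_fn(mem2_rec[step]), targets)

optimizer.zero_grad()
loss.backward()      # BPTT: gradients flow back through all time steps
optimizer.step()
```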
{
"cell_type": "markdown",
"source": [
"### 4.2 Test Set Accuracy"
"### 4.2 Test Set Accuracy\n",
"This function just iterates over all minibatches to obtain a measure of accurcy over the full 10,000 samples in the test set."
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -565,8 +591,8 @@
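As described above, measuring test accuracy amounts to looping over the test set and counting how often the most active output neuron matches the label. A rough sketch, assuming `net`, `test_loader`, `batch_size`, and `device` from earlier cells:

```python
import torch

total = 0
correct = 0

with torch.no_grad():
    net.eval()
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)

        # forward pass: the same static image is presented at every time step
        spk_rec, _ = net(data.view(batch_size, -1))

        # the predicted class is the output neuron with the highest spike count
        _, predicted = spk_rec.sum(dim=0).max(1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)

print(f"Test set accuracy: {100 * correct / total:.2f}%")
```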
"from snntorch import spikegen\n",
"\n",
"# MNIST to spiking-MNIST\n",
"spike_data, spike_targets = spikegen.rate(data_it, targets_it, num_outputs=num_outputs, num_steps=num_steps,\n",
" gain=1, offset=0, convert_targets=False, temporal_targets=False)"
"spike_data, spike_targets = spikegen.rate(data_it, targets_it, num_outputs=num_outputs, num_steps=num_steps, gain=1,\n",
" offset=0, convert_targets=False, temporal_targets=False)"
],
"metadata": {
"collapsed": false,
Expand Down Expand Up @@ -665,23 +691,24 @@
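For context on the conversion above: rate coding treats each normalized pixel intensity as the probability of emitting a spike at every time step, so bright pixels spike often and dark pixels rarely. A minimal sketch of that idea (not the `spikegen.rate` implementation itself):

```python
import torch

def rate_code_sketch(images, num_steps, gain=1.0):
    """Bernoulli rate coding: pixel intensity in [0, 1] sets the per-step spike probability."""
    prob = (images * gain).clamp(0.0, 1.0)
    # draw an independent Bernoulli spike train of length num_steps for every pixel
    return torch.bernoulli(prob.expand(num_steps, *prob.shape))
```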
"execution_count": null,
"outputs": [],
"source": [
"spike_fn = FSS.apply\n",
"snn.neuron.slope = 50\n",
"spike_grad = FSS.apply\n",
"snn.neuron.slope = 50 # The lower the slope, the smoother the gradient\n",
"\n",
"# Define Network\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" # initialize layers\n",
" # initialize layers\n",
" self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
" self.lif1 = LIF(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
" self.lif1 = Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
" self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
" self.lif2 = LIF(spike_fn=spike_fn, alpha=alpha, beta=beta)\n",
" self.lif2 = Stein(alpha=alpha, beta=beta, spike_grad=spike_grad)\n",
"\n",
" def forward(self, x):\n",
" spk1, syn1, mem1 = self.lif1.init_hidden(batch_size, num_hidden)\n",
" spk2, syn2, mem2 = self.lif2.init_hidden(batch_size, num_outputs)\n",
" # Initialize hidden states + output spike at t=0\n",
" spk1, syn1, mem1 = self.lif1.init_stein(batch_size, num_hidden)\n",
" spk2, syn2, mem2 = self.lif2.init_stein(batch_size, num_outputs)\n",
"\n",
" spk2_rec = []\n",
" mem2_rec = []\n",
Expand Down Expand Up @@ -724,14 +751,12 @@
"execution_count": null,
"outputs": [],
"source": [
"# Print batch accuracy function\n",
"\n",
"def print_batch_accuracy(data, targets, train=False):\n",
" output, _ = net(data.view(num_steps, batch_size, -1))\n",
" _, am = output.sum(dim=0).max(1)\n",
" acc = np.mean((targets == am). detach().cpu().numpy())\n",
" _, idx = output.sum(dim=0).max(1)\n",
" acc = np.mean((targets == idx).detach().cpu().numpy())\n",
"\n",
" if train is True:\n",
" if train:\n",
" print(f\"Train Set Accuracy: {acc}\")\n",
" else:\n",
" print(f\"Test Set Accuracy: {acc}\")\n",
Expand Down Expand Up @@ -961,7 +986,8 @@
{
"cell_type": "markdown",
"source": [
"That's all for now! Next time, we'll make a few tiny modifications to show how we can crank accuracy up even further by using spiking convolutional layers."
"That's all for now!\n",
"Next time, we'll introduce how to use spiking convolutional layers to improve accuracy."
],
"metadata": {
"collapsed": false
Expand All @@ -970,9 +996,9 @@
],
"metadata": {
"kernelspec": {
"name": "pycharm-78d010ff",
"name": "python3",
"language": "python",
"display_name": "PyCharm (snntorch)"
"display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
Expand Down