Skip to content

Commit 7619dde

Browse files
update all tutorials
1 parent ca99b9b commit 7619dde

6 files changed

+205
-244
lines changed

tutorials/01_small_network.ipynb

+23-29
Large diffs are not rendered by default.

tutorials/02_setting_parameters.ipynb

+28-57
Large diffs are not rendered by default.

tutorials/03_gradient.ipynb

+65-67
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"outputs": [],
3030
"source": [
3131
"# I have experienced stability issues with float32.\n",
32-
"from jax.config import config\n",
32+
"from jax import config\n",
3333
"config.update(\"jax_enable_x64\", True)\n",
3434
"config.update(\"jax_platform_name\", \"cpu\")\n",
3535
"\n",
@@ -44,8 +44,6 @@
4444
"metadata": {},
4545
"outputs": [],
4646
"source": [
47-
"import time\n",
48-
"import matplotlib.pyplot as plt\n",
4947
"import numpy as np\n",
5048
"import jax\n",
5149
"import jax.numpy as jnp\n",
@@ -128,7 +126,7 @@
128126
},
129127
{
130128
"cell_type": "code",
131-
"execution_count": 8,
129+
"execution_count": 11,
132130
"id": "ff784bcb",
133131
"metadata": {},
134132
"outputs": [],
@@ -138,7 +136,7 @@
138136
},
139137
{
140138
"cell_type": "code",
141-
"execution_count": 9,
139+
"execution_count": 12,
142140
"id": "222f9a00",
143141
"metadata": {},
144142
"outputs": [],
@@ -148,15 +146,15 @@
148146
},
149147
{
150148
"cell_type": "code",
151-
"execution_count": 10,
149+
"execution_count": 13,
152150
"id": "90affb0c-dc77-47c7-be3c-18e2829cc820",
153151
"metadata": {},
154152
"outputs": [],
155153
"source": [
156154
"for cell_ind in range(5):\n",
157155
" network.cell(cell_ind).branch(1).comp(0.0).record()\n",
158156
" \n",
159-
"current = jx.step_current(i_delay, i_dur, i_amp, time_vec)\n",
157+
"current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\n",
160158
"for stim_ind in range(2):\n",
161159
" network.cell(stim_ind).branch(1).comp(0.0).stimulate(current)"
162160
]
@@ -179,10 +177,18 @@
179177
},
180178
{
181179
"cell_type": "code",
182-
"execution_count": 11,
180+
"execution_count": 14,
183181
"id": "10cb5b1e",
184182
"metadata": {},
185-
"outputs": [],
183+
"outputs": [
184+
{
185+
"name": "stdout",
186+
"output_type": "stream",
187+
"text": [
188+
"Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n"
189+
]
190+
}
191+
],
186192
"source": [
187193
"network.make_trainable(\"radius\")"
188194
]
@@ -197,10 +203,18 @@
197203
},
198204
{
199205
"cell_type": "code",
200-
"execution_count": 12,
206+
"execution_count": 15,
201207
"id": "c90be7f3",
202208
"metadata": {},
203-
"outputs": [],
209+
"outputs": [
210+
{
211+
"name": "stdout",
212+
"output_type": "stream",
213+
"text": [
214+
"Number of newly added trainable parameters: 200. Total number of trainable parameters: 201\n"
215+
]
216+
}
217+
],
204218
"source": [
205219
"network.cell(\"all\").branch(\"all\").comp(\"all\").make_trainable(\"gNa\")"
206220
]
@@ -223,10 +237,18 @@
223237
},
224238
{
225239
"cell_type": "code",
226-
"execution_count": 13,
240+
"execution_count": 16,
227241
"id": "f31901bd",
228242
"metadata": {},
229-
"outputs": [],
243+
"outputs": [
244+
{
245+
"name": "stdout",
246+
"output_type": "stream",
247+
"text": [
248+
"Number of newly added trainable parameters: 1. Total number of trainable parameters: 202\n"
249+
]
250+
}
251+
],
230252
"source": [
231253
"network.make_trainable(\"gS\")"
232254
]
@@ -241,10 +263,18 @@
241263
},
242264
{
243265
"cell_type": "code",
244-
"execution_count": 14,
266+
"execution_count": 17,
245267
"id": "12fe7828",
246268
"metadata": {},
247-
"outputs": [],
269+
"outputs": [
270+
{
271+
"name": "stdout",
272+
"output_type": "stream",
273+
"text": [
274+
"Number of newly added trainable parameters: 6. Total number of trainable parameters: 208\n"
275+
]
276+
}
277+
],
248278
"source": [
249279
"network.GlutamateSynapse(\"all\").make_trainable(\"gS\")"
250280
]
@@ -267,7 +297,7 @@
267297
},
268298
{
269299
"cell_type": "code",
270-
"execution_count": 15,
300+
"execution_count": 18,
271301
"id": "40a48eea",
272302
"metadata": {},
273303
"outputs": [],
@@ -286,7 +316,7 @@
286316
},
287317
{
288318
"cell_type": "code",
289-
"execution_count": 17,
319+
"execution_count": 19,
290320
"id": "4eb3f8f1",
291321
"metadata": {},
292322
"outputs": [],
@@ -312,7 +342,7 @@
312342
},
313343
{
314344
"cell_type": "code",
315-
"execution_count": 24,
345+
"execution_count": 20,
316346
"id": "a29f1ac2",
317347
"metadata": {},
318348
"outputs": [],
@@ -332,7 +362,7 @@
332362
},
333363
{
334364
"cell_type": "code",
335-
"execution_count": 25,
365+
"execution_count": 21,
336366
"id": "f38d61a9",
337367
"metadata": {},
338368
"outputs": [],
@@ -342,7 +372,7 @@
342372
},
343373
{
344374
"cell_type": "code",
345-
"execution_count": 26,
375+
"execution_count": 22,
346376
"id": "9ac97e04",
347377
"metadata": {},
348378
"outputs": [],
@@ -362,22 +392,10 @@
362392
},
363393
{
364394
"cell_type": "code",
365-
"execution_count": 27,
395+
"execution_count": 23,
366396
"id": "d9ccf1b6",
367397
"metadata": {},
368-
"outputs": [
369-
{
370-
"ename": "ModuleNotFoundError",
371-
"evalue": "No module named 'optax'",
372-
"output_type": "error",
373-
"traceback": [
374-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
375-
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
376-
"\u001b[0;32m/var/folders/j9/w7ftvg_16t1f9bgp1cy4bt1r0000gn/T/ipykernel_6746/3781452166.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0moptax\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
377-
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'optax'"
378-
]
379-
}
380-
],
398+
"outputs": [],
381399
"source": [
382400
"import optax"
383401
]
@@ -392,7 +410,7 @@
392410
},
393411
{
394412
"cell_type": "code",
395-
"execution_count": null,
413+
"execution_count": 24,
396414
"id": "710a1545",
397415
"metadata": {},
398416
"outputs": [],
@@ -413,7 +431,7 @@
413431
},
414432
{
415433
"cell_type": "code",
416-
"execution_count": null,
434+
"execution_count": 25,
417435
"id": "800f959e",
418436
"metadata": {},
419437
"outputs": [],
@@ -437,51 +455,31 @@
437455
},
438456
{
439457
"cell_type": "code",
440-
"execution_count": null,
458+
"execution_count": 26,
441459
"id": "9d639efa",
442460
"metadata": {},
443461
"outputs": [],
444462
"source": [
445463
"opt_params = transform.inverse(params)\n",
446-
"optimizer = optax.adam(learning_rate=1e-2)\n",
464+
"optimizer = optax.adam(learning_rate=1e-1)\n",
447465
"opt_state = optimizer.init(opt_params)"
448466
]
449467
},
450468
{
451469
"cell_type": "code",
452-
"execution_count": null,
470+
"execution_count": 27,
453471
"id": "0e4aebd0-283e-4165-8c24-4b6fb811135e",
454472
"metadata": {},
455-
"outputs": [],
456-
"source": [
457-
"epoch_losses = []\n",
458-
"\n",
459-
"for epoch in range(5):\n",
460-
" loss_val, gradient = jitted_grad(opt_params)\n",
461-
" updates, opt_state = optimizer.update(gradient, opt_state)\n",
462-
" opt_params = optax.apply_updates(opt_params, updates)\n",
463-
"\n",
464-
" print(f\"epoch {epoch}, loss {loss_val}\")\n",
465-
" epoch_losses.append(loss_val)\n",
466-
" \n",
467-
"final_params = transform.forward(opt_params)"
468-
]
469-
},
470-
{
471-
"cell_type": "code",
472-
"execution_count": 43,
473-
"id": "134af3e1",
474-
"metadata": {},
475473
"outputs": [
476474
{
477475
"name": "stdout",
478476
"output_type": "stream",
479477
"text": [
480-
"epoch 0, loss -64.97740510297487\n",
481-
"epoch 1, loss -64.98296369502924\n",
482-
"epoch 2, loss -64.98846441030534\n",
483-
"epoch 3, loss -64.99390672049375\n",
484-
"epoch 4, loss -64.99929015505666\n"
478+
"epoch 0, loss -64.97740512171988\n",
479+
"epoch 1, loss -65.03050878143252\n",
480+
"epoch 2, loss -65.07820463321355\n",
481+
"epoch 3, loss -65.12091871011332\n",
482+
"epoch 4, loss -65.15909222980945\n"
485483
]
486484
}
487485
],
@@ -510,9 +508,9 @@
510508
],
511509
"metadata": {
512510
"kernelspec": {
513-
"display_name": "jax",
511+
"display_name": "neurax",
514512
"language": "python",
515-
"name": "jax"
513+
"name": "python3"
516514
},
517515
"language_info": {
518516
"codemirror_mode": {
@@ -524,7 +522,7 @@
524522
"name": "python",
525523
"nbconvert_exporter": "python",
526524
"pygments_lexer": "ipython3",
527-
"version": "3.8.12"
525+
"version": "3.10.11"
528526
}
529527
},
530528
"nbformat": 4,

tutorials/04_groups.ipynb

+27-11
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"outputs": [],
3030
"source": [
3131
"# I have experienced stability issues with float32.\n",
32-
"from jax.config import config\n",
32+
"from jax import config\n",
3333
"config.update(\"jax_enable_x64\", True)\n",
3434
"config.update(\"jax_platform_name\", \"cpu\")\n",
3535
"\n",
@@ -618,7 +618,15 @@
618618
"execution_count": 14,
619619
"id": "3a399a56",
620620
"metadata": {},
621-
"outputs": [],
621+
"outputs": [
622+
{
623+
"name": "stdout",
624+
"output_type": "stream",
625+
"text": [
626+
"Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n"
627+
]
628+
}
629+
],
622630
"source": [
623631
"network.fast_spiking.make_trainable(\"gNa\")"
624632
]
@@ -640,7 +648,7 @@
640648
{
641649
"data": {
642650
"text/plain": [
643-
"[{'gNa': DeviceArray([[0.4]], dtype=float64)}]"
651+
"[{'gNa': Array([[0.4]], dtype=float64)}]"
644652
]
645653
},
646654
"execution_count": 15,
@@ -665,7 +673,15 @@
665673
"execution_count": 16,
666674
"id": "99a6c389",
667675
"metadata": {},
668-
"outputs": [],
676+
"outputs": [
677+
{
678+
"name": "stdout",
679+
"output_type": "stream",
680+
"text": [
681+
"Number of newly added trainable parameters: 3. Total number of trainable parameters: 4\n"
682+
]
683+
}
684+
],
669685
"source": [
670686
"network.cell([0,1,3]).make_trainable(\"axial_resistivity\")"
671687
]
@@ -679,10 +695,10 @@
679695
{
680696
"data": {
681697
"text/plain": [
682-
"[{'gNa': DeviceArray([[0.4]], dtype=float64)},\n",
683-
" {'axial_resistivity': DeviceArray([[5000.],\n",
684-
" [5000.],\n",
685-
" [5000.]], dtype=float64)}]"
698+
"[{'gNa': Array([[0.4]], dtype=float64)},\n",
699+
" {'axial_resistivity': Array([[5000.],\n",
700+
" [5000.],\n",
701+
" [5000.]], dtype=float64)}]"
686702
]
687703
},
688704
"execution_count": 17,
@@ -705,9 +721,9 @@
705721
],
706722
"metadata": {
707723
"kernelspec": {
708-
"display_name": "jax",
724+
"display_name": "neurax",
709725
"language": "python",
710-
"name": "jax"
726+
"name": "python3"
711727
},
712728
"language_info": {
713729
"codemirror_mode": {
@@ -719,7 +735,7 @@
719735
"name": "python",
720736
"nbconvert_exporter": "python",
721737
"pygments_lexer": "ipython3",
722-
"version": "3.8.12"
738+
"version": "3.10.11"
723739
}
724740
},
725741
"nbformat": 4,

0 commit comments

Comments
 (0)