|
29 | 29 | "outputs": [],
|
30 | 30 | "source": [
|
31 | 31 | "# I have experienced stability issues with float32.\n",
|
32 |
| - "from jax.config import config\n", |
| 32 | + "from jax import config\n", |
33 | 33 | "config.update(\"jax_enable_x64\", True)\n",
|
34 | 34 | "config.update(\"jax_platform_name\", \"cpu\")\n",
|
35 | 35 | "\n",
|
|
44 | 44 | "metadata": {},
|
45 | 45 | "outputs": [],
|
46 | 46 | "source": [
|
47 |
| - "import time\n", |
48 |
| - "import matplotlib.pyplot as plt\n", |
49 | 47 | "import numpy as np\n",
|
50 | 48 | "import jax\n",
|
51 | 49 | "import jax.numpy as jnp\n",
|
|
128 | 126 | },
|
129 | 127 | {
|
130 | 128 | "cell_type": "code",
|
131 |
| - "execution_count": 8, |
| 129 | + "execution_count": 11, |
132 | 130 | "id": "ff784bcb",
|
133 | 131 | "metadata": {},
|
134 | 132 | "outputs": [],
|
|
138 | 136 | },
|
139 | 137 | {
|
140 | 138 | "cell_type": "code",
|
141 |
| - "execution_count": 9, |
| 139 | + "execution_count": 12, |
142 | 140 | "id": "222f9a00",
|
143 | 141 | "metadata": {},
|
144 | 142 | "outputs": [],
|
|
148 | 146 | },
|
149 | 147 | {
|
150 | 148 | "cell_type": "code",
|
151 |
| - "execution_count": 10, |
| 149 | + "execution_count": 13, |
152 | 150 | "id": "90affb0c-dc77-47c7-be3c-18e2829cc820",
|
153 | 151 | "metadata": {},
|
154 | 152 | "outputs": [],
|
155 | 153 | "source": [
|
156 | 154 | "for cell_ind in range(5):\n",
|
157 | 155 | " network.cell(cell_ind).branch(1).comp(0.0).record()\n",
|
158 | 156 | " \n",
|
159 |
| - "current = jx.step_current(i_delay, i_dur, i_amp, time_vec)\n", |
| 157 | + "current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\n", |
160 | 158 | "for stim_ind in range(2):\n",
|
161 | 159 | " network.cell(stim_ind).branch(1).comp(0.0).stimulate(current)"
|
162 | 160 | ]
|
|
179 | 177 | },
|
180 | 178 | {
|
181 | 179 | "cell_type": "code",
|
182 |
| - "execution_count": 11, |
| 180 | + "execution_count": 14, |
183 | 181 | "id": "10cb5b1e",
|
184 | 182 | "metadata": {},
|
185 |
| - "outputs": [], |
| 183 | + "outputs": [ |
| 184 | + { |
| 185 | + "name": "stdout", |
| 186 | + "output_type": "stream", |
| 187 | + "text": [ |
| 188 | + "Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n" |
| 189 | + ] |
| 190 | + } |
| 191 | + ], |
186 | 192 | "source": [
|
187 | 193 | "network.make_trainable(\"radius\")"
|
188 | 194 | ]
|
|
197 | 203 | },
|
198 | 204 | {
|
199 | 205 | "cell_type": "code",
|
200 |
| - "execution_count": 12, |
| 206 | + "execution_count": 15, |
201 | 207 | "id": "c90be7f3",
|
202 | 208 | "metadata": {},
|
203 |
| - "outputs": [], |
| 209 | + "outputs": [ |
| 210 | + { |
| 211 | + "name": "stdout", |
| 212 | + "output_type": "stream", |
| 213 | + "text": [ |
| 214 | + "Number of newly added trainable parameters: 200. Total number of trainable parameters: 201\n" |
| 215 | + ] |
| 216 | + } |
| 217 | + ], |
204 | 218 | "source": [
|
205 | 219 | "network.cell(\"all\").branch(\"all\").comp(\"all\").make_trainable(\"gNa\")"
|
206 | 220 | ]
|
|
223 | 237 | },
|
224 | 238 | {
|
225 | 239 | "cell_type": "code",
|
226 |
| - "execution_count": 13, |
| 240 | + "execution_count": 16, |
227 | 241 | "id": "f31901bd",
|
228 | 242 | "metadata": {},
|
229 |
| - "outputs": [], |
| 243 | + "outputs": [ |
| 244 | + { |
| 245 | + "name": "stdout", |
| 246 | + "output_type": "stream", |
| 247 | + "text": [ |
| 248 | + "Number of newly added trainable parameters: 1. Total number of trainable parameters: 202\n" |
| 249 | + ] |
| 250 | + } |
| 251 | + ], |
230 | 252 | "source": [
|
231 | 253 | "network.make_trainable(\"gS\")"
|
232 | 254 | ]
|
|
241 | 263 | },
|
242 | 264 | {
|
243 | 265 | "cell_type": "code",
|
244 |
| - "execution_count": 14, |
| 266 | + "execution_count": 17, |
245 | 267 | "id": "12fe7828",
|
246 | 268 | "metadata": {},
|
247 |
| - "outputs": [], |
| 269 | + "outputs": [ |
| 270 | + { |
| 271 | + "name": "stdout", |
| 272 | + "output_type": "stream", |
| 273 | + "text": [ |
| 274 | + "Number of newly added trainable parameters: 6. Total number of trainable parameters: 208\n" |
| 275 | + ] |
| 276 | + } |
| 277 | + ], |
248 | 278 | "source": [
|
249 | 279 | "network.GlutamateSynapse(\"all\").make_trainable(\"gS\")"
|
250 | 280 | ]
|
|
267 | 297 | },
|
268 | 298 | {
|
269 | 299 | "cell_type": "code",
|
270 |
| - "execution_count": 15, |
| 300 | + "execution_count": 18, |
271 | 301 | "id": "40a48eea",
|
272 | 302 | "metadata": {},
|
273 | 303 | "outputs": [],
|
|
286 | 316 | },
|
287 | 317 | {
|
288 | 318 | "cell_type": "code",
|
289 |
| - "execution_count": 17, |
| 319 | + "execution_count": 19, |
290 | 320 | "id": "4eb3f8f1",
|
291 | 321 | "metadata": {},
|
292 | 322 | "outputs": [],
|
|
312 | 342 | },
|
313 | 343 | {
|
314 | 344 | "cell_type": "code",
|
315 |
| - "execution_count": 24, |
| 345 | + "execution_count": 20, |
316 | 346 | "id": "a29f1ac2",
|
317 | 347 | "metadata": {},
|
318 | 348 | "outputs": [],
|
|
332 | 362 | },
|
333 | 363 | {
|
334 | 364 | "cell_type": "code",
|
335 |
| - "execution_count": 25, |
| 365 | + "execution_count": 21, |
336 | 366 | "id": "f38d61a9",
|
337 | 367 | "metadata": {},
|
338 | 368 | "outputs": [],
|
|
342 | 372 | },
|
343 | 373 | {
|
344 | 374 | "cell_type": "code",
|
345 |
| - "execution_count": 26, |
| 375 | + "execution_count": 22, |
346 | 376 | "id": "9ac97e04",
|
347 | 377 | "metadata": {},
|
348 | 378 | "outputs": [],
|
|
362 | 392 | },
|
363 | 393 | {
|
364 | 394 | "cell_type": "code",
|
365 |
| - "execution_count": 27, |
| 395 | + "execution_count": 23, |
366 | 396 | "id": "d9ccf1b6",
|
367 | 397 | "metadata": {},
|
368 |
| - "outputs": [ |
369 |
| - { |
370 |
| - "ename": "ModuleNotFoundError", |
371 |
| - "evalue": "No module named 'optax'", |
372 |
| - "output_type": "error", |
373 |
| - "traceback": [ |
374 |
| - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
375 |
| - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", |
376 |
| - "\u001b[0;32m/var/folders/j9/w7ftvg_16t1f9bgp1cy4bt1r0000gn/T/ipykernel_6746/3781452166.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0moptax\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", |
377 |
| - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'optax'" |
378 |
| - ] |
379 |
| - } |
380 |
| - ], |
| 398 | + "outputs": [], |
381 | 399 | "source": [
|
382 | 400 | "import optax"
|
383 | 401 | ]
|
|
392 | 410 | },
|
393 | 411 | {
|
394 | 412 | "cell_type": "code",
|
395 |
| - "execution_count": null, |
| 413 | + "execution_count": 24, |
396 | 414 | "id": "710a1545",
|
397 | 415 | "metadata": {},
|
398 | 416 | "outputs": [],
|
|
413 | 431 | },
|
414 | 432 | {
|
415 | 433 | "cell_type": "code",
|
416 |
| - "execution_count": null, |
| 434 | + "execution_count": 25, |
417 | 435 | "id": "800f959e",
|
418 | 436 | "metadata": {},
|
419 | 437 | "outputs": [],
|
|
437 | 455 | },
|
438 | 456 | {
|
439 | 457 | "cell_type": "code",
|
440 |
| - "execution_count": null, |
| 458 | + "execution_count": 26, |
441 | 459 | "id": "9d639efa",
|
442 | 460 | "metadata": {},
|
443 | 461 | "outputs": [],
|
444 | 462 | "source": [
|
445 | 463 | "opt_params = transform.inverse(params)\n",
|
446 |
| - "optimizer = optax.adam(learning_rate=1e-2)\n", |
| 464 | + "optimizer = optax.adam(learning_rate=1e-1)\n", |
447 | 465 | "opt_state = optimizer.init(opt_params)"
|
448 | 466 | ]
|
449 | 467 | },
|
450 | 468 | {
|
451 | 469 | "cell_type": "code",
|
452 |
| - "execution_count": null, |
| 470 | + "execution_count": 27, |
453 | 471 | "id": "0e4aebd0-283e-4165-8c24-4b6fb811135e",
|
454 | 472 | "metadata": {},
|
455 |
| - "outputs": [], |
456 |
| - "source": [ |
457 |
| - "epoch_losses = []\n", |
458 |
| - "\n", |
459 |
| - "for epoch in range(5):\n", |
460 |
| - " loss_val, gradient = jitted_grad(opt_params)\n", |
461 |
| - " updates, opt_state = optimizer.update(gradient, opt_state)\n", |
462 |
| - " opt_params = optax.apply_updates(opt_params, updates)\n", |
463 |
| - "\n", |
464 |
| - " print(f\"epoch {epoch}, loss {loss_val}\")\n", |
465 |
| - " epoch_losses.append(loss_val)\n", |
466 |
| - " \n", |
467 |
| - "final_params = transform.forward(opt_params)" |
468 |
| - ] |
469 |
| - }, |
470 |
| - { |
471 |
| - "cell_type": "code", |
472 |
| - "execution_count": 43, |
473 |
| - "id": "134af3e1", |
474 |
| - "metadata": {}, |
475 | 473 | "outputs": [
|
476 | 474 | {
|
477 | 475 | "name": "stdout",
|
478 | 476 | "output_type": "stream",
|
479 | 477 | "text": [
|
480 |
| - "epoch 0, loss -64.97740510297487\n", |
481 |
| - "epoch 1, loss -64.98296369502924\n", |
482 |
| - "epoch 2, loss -64.98846441030534\n", |
483 |
| - "epoch 3, loss -64.99390672049375\n", |
484 |
| - "epoch 4, loss -64.99929015505666\n" |
| 478 | + "epoch 0, loss -64.97740512171988\n", |
| 479 | + "epoch 1, loss -65.03050878143252\n", |
| 480 | + "epoch 2, loss -65.07820463321355\n", |
| 481 | + "epoch 3, loss -65.12091871011332\n", |
| 482 | + "epoch 4, loss -65.15909222980945\n" |
485 | 483 | ]
|
486 | 484 | }
|
487 | 485 | ],
|
|
510 | 508 | ],
|
511 | 509 | "metadata": {
|
512 | 510 | "kernelspec": {
|
513 |
| - "display_name": "jax", |
| 511 | + "display_name": "neurax", |
514 | 512 | "language": "python",
|
515 |
| - "name": "jax" |
| 513 | + "name": "python3" |
516 | 514 | },
|
517 | 515 | "language_info": {
|
518 | 516 | "codemirror_mode": {
|
|
524 | 522 | "name": "python",
|
525 | 523 | "nbconvert_exporter": "python",
|
526 | 524 | "pygments_lexer": "ipython3",
|
527 |
| - "version": "3.8.12" |
| 525 | + "version": "3.10.11" |
528 | 526 | }
|
529 | 527 | },
|
530 | 528 | "nbformat": 4,
|
|
0 commit comments