Commit
Add support for graph operations (#355)
* add: add to_graph method
* add: add from_graph method
* fix: isort and rm obs import
* doc: added more comments
* wip: started adding new swc import function
* wip: save dev wip
* add: implement graph2jaxley method, that imposes branch structure on some graph
* fix: small fixes
* enh: add prelim version of xyzr
* wip: save wip
* enh: refactor of swc->graph->jaxley->module pipeline
* enh: cleanup
* fix: small fix, now from graph runs on compartmentalized morphology
* doc: add docstrings and type hints
* doc: more comments added
* fix: small fixes.
* fix: rm undefined from groups at import
* wip: save wip
* wip: save wip
* rm: remove dev notebook from tracking
* add: add tests for graph functionalities
* fix: added more tests and they are now passing
* fix: test fixes
* wip: save wip working on graph and swc io
* fix: radius import fixed
* wip: save wip
* enh: massive overhaul of graph pipeline complete; now much simpler and NEURON comparison tests pass
* wip: save wip
* wip: save wip
* fix: some fixes added
* wip: save wip
* wip: save wip tests
* wip: tests look better than before
* wip: remove complexity and improve test MSE. still not there though
* wip: save wip
* rm: rm notebooks from pr
* wip: save wip
* enh: small refactor of swc -> initial graph pipe
* enh: massive overhaul of graph pipe; add documentation; passes tests now
* wip: start adding tests
* enh: include graph IO in tutorial
* wip: progress on tests and tutorial
* wip: tests are passing, except for voltages, which only passes in notebook but not in pytest
* rm: dev notebook removed
* fix: tests passing for non-single-soma morpho
* enh: tests are finally passing
* fix: add missing kwarg in simulate_trace_error
* fix: fix diff with main
* fix: rm diff in modules/base
* fix: add __eq__ back in for comparisons of cells attr in net and fix asterisk
* chore: ran black
* fix: change read_swc to io imports
* chore: add license header
* enh: fixup of wording
* fix: fix merge artefact
* rm: rm tutorials left from rebase
* wip: swc pipe now works up until module import
* wip: in/out pipeline working for morphology, but without attrs like recordings
* fix: update tutorial
* fix: fix tutorial and synapses and input output
* fix: add networkx as dep
* fix: rebase io/swc.py to main
* doc: add more function documentation
* wip: step 1 on getting tests to pass
* fix: reduce diff
* fix: amend last commit
* fix: pop l when converting graph
* wip: working on tests
* wip: all but one test passing, working on import export cycle validation
* fix: finished import export tests and all tests are passing
* fix: speed up tests and add docs
* chore: edit changelog
* fix: fix import cycle error, add min_radius and make some functions private
* add: add new test morphology
* fix: refactor swc and combine both swc_readers into a combined method
* wip: change root finder and fix testcase
* fix: made more methods private, fix issues with 0-length edges in graph
* chore: edit changelog
* chore: reformat changelog
* fix: skip new morph test for now
Showing 14 changed files with 12,871 additions and 292 deletions.
@@ -0,0 +1,5 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from jaxley.io.graph import from_graph, to_graph
from jaxley.io.swc import read_swc
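A minimal usage sketch of the graph IO exported here, assembled from calls that appear verbatim in the tests further down in this commit; the "morph.swc" path is a placeholder, and whether read_swc wraps the same SWC-to-graph pipeline internally is an assumption.

```python
# Minimal sketch of the exported graph IO, using only calls that appear in the
# tests of this commit; "morph.swc" is a placeholder path.
from jaxley.io.graph import from_graph, swc_to_graph, to_graph

graph = swc_to_graph("morph.swc")  # SWC file -> networkx graph
cell = from_graph(
    graph, ncomp=8, max_branch_len=2000, ignore_swc_trace_errors=False
)  # graph -> jaxley module
exported = to_graph(cell, channels=True, synapses=True)  # module -> networkx graph
cell_again = from_graph(exported)  # ... and back
```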
@@ -35,6 +35,7 @@ dependencies = [
     "numpy",
     "pandas>=2.2.0",
     "tridiax",
+    "networkx",
 ]

 [project.optional-dependencies]
@@ -0,0 +1,158 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import numpy as np
import pandas as pd


def get_segment_xyzrL(section, comp_idx=None, loc=None, nseg=8):
    assert (
        comp_idx is not None or loc is not None
    ), "Either comp_idx or loc must be provided."
    assert not (
        comp_idx is not None and loc is not None
    ), "Only one of comp_idx or loc can be provided."

    comp_len = 1 / nseg
    loc = comp_len / 2 + comp_idx * comp_len if loc is None else loc

    n3d = section.n3d()
    x3d = np.array([section.x3d(i) for i in range(n3d)])
    y3d = np.array([section.y3d(i) for i in range(n3d)])
    z3d = np.array([section.z3d(i) for i in range(n3d)])
    L = np.array([section.arc3d(i) for i in range(n3d)])  # cumulative arc lengths
    r3d = np.array([section.diam3d(i) / 2 for i in range(n3d)])
    if loc is None:
        return x3d, y3d, z3d, r3d
    else:
        total_length = L[-1]
        target_length = loc * total_length

        # Find segment containing target_length
        for i in range(1, n3d):
            if L[i] >= target_length:
                break
        else:
            i = n3d - 1

        # Interpolate between points i-1 and i
        L0, L1 = L[i - 1], L[i]
        t = (target_length - L0) / (L1 - L0)
        x = x3d[i - 1] + t * (x3d[i] - x3d[i - 1])
        y = y3d[i - 1] + t * (y3d[i] - y3d[i - 1])
        z = z3d[i - 1] + t * (z3d[i] - z3d[i - 1])
        r = r3d[i - 1] + t * (r3d[i] - r3d[i - 1])
        return x, y, z, r, L[-1] / nseg
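The loop above brackets the requested arc length between two traced points and interpolates linearly. A minimal standalone illustration of the same idea, with made-up numbers:

```python
# Illustration only (made-up values): linear interpolation along cumulative arc
# length, equivalent to the bracket-and-interpolate loop above.
import numpy as np

L = np.array([0.0, 2.0, 5.0, 9.0])    # cumulative arc lengths of four traced points
x3d = np.array([0.0, 1.0, 3.0, 4.0])  # x coordinate at each traced point
loc = 0.5                             # halfway along the section
target = loc * L[-1]                  # 4.5, which falls between L[1]=2.0 and L[2]=5.0
x = np.interp(target, L, x3d)         # 1.0 + (4.5 - 2.0) / 3.0 * (3.0 - 1.0) ≈ 2.67
```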


def jaxley2neuron_by_coords(jx_cell, neuron_secs, comp_idx=None, loc=None, nseg=8):
    neuron_coords = {
        i: np.vstack(get_segment_xyzrL(sec, comp_idx=comp_idx, loc=loc, nseg=nseg))[
            :3
        ].T
        for i, sec in enumerate(neuron_secs)
    }
    neuron_coords = np.vstack(
        [np.hstack([k * np.ones((v.shape[0], 1)), v]) for k, v in neuron_coords.items()]
    )
    neuron_coords = pd.DataFrame(
        neuron_coords, columns=["global_branch_index", "x", "y", "z"]
    )
    neuron_coords["global_branch_index"] = neuron_coords["global_branch_index"].astype(
        int
    )

    neuron_loc_xyz = neuron_coords.groupby("global_branch_index").mean()
    jaxley_loc_xyz = (
        jx_cell.branch("all")
        .loc(loc)
        .nodes.set_index("global_branch_index")[["x", "y", "z"]]
    )

    jaxley2neuron_inds = {}
    for i, xyz in enumerate(jaxley_loc_xyz.to_numpy()):
        d = np.sqrt(((neuron_loc_xyz - xyz) ** 2)).sum(axis=1)
        jaxley2neuron_inds[i] = d.argmin()
    return jaxley2neuron_inds


def jaxley2neuron_by_group(
    jx_cell,
    neuron_secs,
    comp_idx=None,
    loc=None,
    nseg=8,
    num_apical=20,
    num_tuft=20,
    num_basal=10,
):
    y_apical = (
        jx_cell.apical.nodes.groupby("global_branch_index")
        .mean()["y"]
        .abs()
        .sort_values()
    )
    trunk_inds = y_apical.index[:num_apical].tolist()
    tuft_inds = y_apical.index[-num_tuft:].tolist()
    basal_inds = (
        jx_cell.basal.nodes["global_branch_index"].unique()[:num_basal].tolist()
    )

    jaxley2neuron = jaxley2neuron_by_coords(
        jx_cell, neuron_secs, comp_idx=comp_idx, loc=loc, nseg=nseg
    )

    neuron_trunk_inds = [jaxley2neuron[i] for i in trunk_inds]
    neuron_tuft_inds = [jaxley2neuron[i] for i in tuft_inds]
    neuron_basal_inds = [jaxley2neuron[i] for i in basal_inds]

    neuron_inds = {
        "trunk": neuron_trunk_inds,
        "tuft": neuron_tuft_inds,
        "basal": neuron_basal_inds,
    }
    jaxley_inds = {"trunk": trunk_inds, "tuft": tuft_inds, "basal": basal_inds}
    return neuron_inds, jaxley_inds


def match_stim_loc(jx_cell, neuron_sec, comp_idx=None, loc=None, nseg=8):
    stim_coords = get_segment_xyzrL(neuron_sec, comp_idx=comp_idx, loc=loc, nseg=nseg)[
        :3
    ]
    stim_idx = (
        ((jx_cell.nodes[["x", "y", "z"]] - stim_coords) ** 2).sum(axis=1).argmin()
    )
    return stim_idx


def import_neuron_morph(fname, nseg=8):
    from neuron import h

    _ = h.load_file("stdlib.hoc")
    _ = h.load_file("import3d.hoc")

    ##################### NEURON ##################
    for sec in h.allsec():
        h.delete_section(sec=sec)

    cell = h.Import3d_SWC_read()
    cell.input(fname)
    i3d = h.Import3d_GUI(cell, False)
    i3d.instantiate(None)

    for sec in h.allsec():
        sec.nseg = nseg
    return h, cell


def equal_both_nan_or_empty_df(a, b):
    if a.empty and b.empty:
        return True
    a[a.isna()] = -1
    b[b.isna()] = -1
    if set(a.columns) != set(b.columns):
        return False
    else:
        a = a[b.columns]
    return (a == b).all()
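A hypothetical usage example of the comparison helper above (values made up): frames that differ only in column order, with NaNs in the same positions, compare as equal; note that the helper overwrites NaNs with -1 in place, so it mutates its arguments.

```python
# Hypothetical example (made-up values) of the DataFrame comparison helper above.
import numpy as np
import pandas as pd

from tests.helpers import equal_both_nan_or_empty_df

a = pd.DataFrame({"x": [1.0, np.nan], "y": [3.0, 4.0]})
b = pd.DataFrame({"y": [3.0, 4.0], "x": [1.0, np.nan]})
print(np.all(equal_both_nan_or_empty_df(a, b)))  # True: column order and NaNs are tolerated
```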
@@ -0,0 +1,378 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import os
from copy import deepcopy

import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd
import pytest

import jaxley as jx
from jaxley import connect
from jaxley.channels import HH
from jaxley.channels.pospischil import K, Leak, Na
from jaxley.io.graph import (
    add_missing_graph_attrs,
    from_graph,
    make_jaxley_compatible,
    swc_to_graph,
    to_graph,
    trace_branches,
)
from jaxley.synapses import IonotropicSynapse, TestSynapse

# from jaxley.utils.misc_utils import recursive_compare
from tests.helpers import (
    equal_both_nan_or_empty_df,
    get_segment_xyzrL,
    import_neuron_morph,
    jaxley2neuron_by_group,
    match_stim_loc,
)


# Test exporting and re-importing of different modules.
def test_graph_import_export_cycle(
    SimpleComp, SimpleBranch, SimpleCell, SimpleNet, SimpleMorphCell
):
    np.random.seed(0)
    comp = SimpleComp()
    branch = SimpleBranch(4)
    cell = SimpleCell(5, 4)
    morph_cell = SimpleMorphCell(ncomp=1)
    net = SimpleNet(3, 5, 4)

    # add synapses
    connect(net[0, 0, 0], net[1, 0, 0], IonotropicSynapse())
    connect(net[0, 0, 1], net[1, 0, 1], IonotropicSynapse())
    # connect(net[0, 0, 1], net[1, 0, 1], TestSynapse())  # makes test fail, see warning w. synapses = True

    # add groups
    net.cell(2).add_to_group("cell2")
    net.cell(2).branch(1).add_to_group("cell2branch1")

    # add ion channels
    net.cell(0).insert(Na())
    net.cell(0).insert(Leak())
    net.cell(1).branch(1).insert(Na())
    net.cell(0).insert(K())

    # test consistency of exported and re-imported modules
    for module in [comp, branch, cell, net, morph_cell]:
        module.compute_xyz()  # ensure x, y, z are in nodes before exporting, for later comparison
        module_graph = to_graph(
            module, channels=True, synapses=True
        )  # ensure to_graph works
        re_module = from_graph(module_graph)  # ensure the previously exported graph can be read
        re_module_graph = to_graph(
            re_module, channels=True, synapses=True
        )  # ensure to_graph works for re-imported modules

        # ensure original module and re-imported module are equal
        assert np.all(equal_both_nan_or_empty_df(re_module.nodes, module.nodes))
        assert np.all(equal_both_nan_or_empty_df(re_module.edges, module.edges))
        assert np.all(
            equal_both_nan_or_empty_df(re_module.branch_edges, module.branch_edges)
        )

        for k in module.groups:
            assert k in re_module.groups
            assert np.all(re_module.groups[k] == module.groups[k])

        for re_xyzr, xyzr in zip(re_module.xyzr, module.xyzr):
            re_xyzr[np.isnan(re_xyzr)] = -1
            xyzr[np.isnan(xyzr)] = -1

            assert np.all(re_xyzr == xyzr)

        re_imported_mechs = re_module.channels + re_module.synapses
        for re_mech, mech in zip(re_imported_mechs, module.channels + module.synapses):
            assert np.all(re_mech.name == mech.name)

        # ensure exported graph and re-exported graph are equal
        node_df = pd.DataFrame(
            [d for i, d in module_graph.nodes(data=True)], index=module_graph.nodes
        ).sort_index()
        re_node_df = pd.DataFrame(
            [d for i, d in re_module_graph.nodes(data=True)],
            index=re_module_graph.nodes,
        ).sort_index()
        assert np.all(equal_both_nan_or_empty_df(node_df, re_node_df))

        edges = pd.DataFrame(
            [
                {
                    "pre_global_comp_index": i,
                    "post_global_comp_index": j,
                    **module_graph.edges[i, j],
                }
                for (i, j) in module_graph.edges
            ]
        )
        re_edges = pd.DataFrame(
            [
                {
                    "pre_global_comp_index": i,
                    "post_global_comp_index": j,
                    **re_module_graph.edges[i, j],
                }
                for (i, j) in re_module_graph.edges
            ]
        )
        assert np.all(equal_both_nan_or_empty_df(edges, re_edges))

        # ignore "externals", "recordings", "trainable_params", "indices_set_by_trainables"
        for k in ["ncomp", "xyzr"]:
            assert module_graph.graph[k] == re_module_graph.graph[k]

        # assume that if the network can be integrated, so can comp, branch, and cell
        if isinstance(module, jx.Network):
            # test integration of re-imported module
            re_module.select(nodes=0).record(verbose=False)
            jx.integrate(re_module, t_max=0.5)


@pytest.mark.parametrize(
    "file", ["morph_single_point_soma.swc", "morph.swc", "bbp_with_axon.swc"]
)
def test_trace_branches(file):
    dirname = os.path.dirname(__file__)
    fname = os.path.join(dirname, "swc_files", file)
    graph = swc_to_graph(fname)

    # pre-processing
    graph = add_missing_graph_attrs(graph)
    graph = trace_branches(graph, None, ignore_swc_trace_errors=False)

    edges = pd.DataFrame([{"u": u, "v": v, **d} for u, v, d in graph.edges(data=True)])
    nx_branch_lens = edges.groupby("branch_index")["l"].sum().to_numpy()
    nx_branch_lens = np.sort(nx_branch_lens)

    # exclude artificial root branch
    if np.isclose(nx_branch_lens[0], 1e-1):
        nx_branch_lens = nx_branch_lens[1:]

    h, _ = import_neuron_morph(fname)
    neuron_branch_lens = np.sort([sec.L for sec in h.allsec()])

    errors = np.abs(neuron_branch_lens - nx_branch_lens)
    # one error is expected, see https://github.com/jaxleyverse/jaxley/issues/140
    assert sum(errors > 1e-3) <= 1


@pytest.mark.parametrize(
    "file", ["morph_single_point_soma.swc", "morph.swc", "bbp_with_axon.swc"]
)
def test_from_graph_vs_NEURON(file):
    ncomp = 8
    dirname = os.path.dirname(__file__)
    fname = os.path.join(dirname, "swc_files", file)

    graph = swc_to_graph(fname)
    cell = from_graph(
        graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False
    )
    cell.compute_compartment_centers()
    h, neuron_cell = import_neuron_morph(fname, nseg=ncomp)

    # remove root branch
    jaxley_comps = cell.nodes[
        ~np.isclose(cell.nodes["length"], 0.1 / ncomp)
    ].reset_index(drop=True)

    jx_branch_lens = (
        jaxley_comps.groupby("global_branch_index")["length"].sum().to_numpy()
    )

    # match by branch lengths
    neuron_xyzd = [np.array(s.psection()["morphology"]["pts3d"]) for s in h.allsec()]
    neuron_branch_lens = np.array(
        [
            np.sqrt((np.diff(n[:, :3], axis=0) ** 2).sum(axis=1)).sum()
            for n in neuron_xyzd
        ]
    )
    neuron_inds = np.argsort(neuron_branch_lens)
    jx_inds = np.argsort(jx_branch_lens)

    neuron_df = pd.DataFrame(columns=["neuron_idx", "x", "y", "z", "radius", "length"])
    jx_df = pd.DataFrame(columns=["jx_idx", "x", "y", "z", "radius", "length"])
    for k in range(len(neuron_inds)):
        neuron_comp_k = np.array(
            [
                get_segment_xyzrL(list(h.allsec())[neuron_inds[k]], comp_idx=i)
                for i in range(ncomp)
            ]
        )
        # make this a dataframe
        neuron_comp_k = pd.DataFrame(
            neuron_comp_k, columns=["x", "y", "z", "radius", "length"]
        )
        neuron_comp_k["idx"] = neuron_inds[k]
        jx_comp_k = jaxley_comps[jaxley_comps["global_branch_index"] == jx_inds[k]][
            ["x", "y", "z", "radius", "length"]
        ]
        jx_comp_k["idx"] = jx_inds[k]
        neuron_df = pd.concat([neuron_df, neuron_comp_k], axis=0, ignore_index=True)
        jx_df = pd.concat([jx_df, jx_comp_k], axis=0, ignore_index=True)

    errors = neuron_df["neuron_idx"].to_frame()
    errors["jx_idx"] = jx_df["jx_idx"]
    errors[["x", "y", "z"]] = neuron_df[["x", "y", "z"]] - jx_df[["x", "y", "z"]]
    errors["xyz"] = np.sqrt((errors[["x", "y", "z"]] ** 2).sum(axis=1))
    errors["radius"] = neuron_df["radius"] - jx_df["radius"]
    errors["length"] = neuron_df["length"] - jx_df["length"]

    # one error is expected, see https://github.com/jaxleyverse/jaxley/issues/140
    assert sum(errors.groupby("jx_idx")["xyz"].max() > 1e-3) <= 1
    assert sum(errors.groupby("jx_idx")["radius"].max() > 1e-3) <= 1
    assert sum(errors.groupby("jx_idx")["length"].max() > 1e-3) <= 1


def test_edges_only_to_jaxley():
    # test if an edge-only graph can be imported into jaxley
    sets_of_edges = [
        [(0, 1), (1, 2), (2, 3)],
        [(0, 1), (1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)],
    ]
    for edges in sets_of_edges:
        edge_graph = nx.DiGraph(edges)
        edge_module = from_graph(edge_graph)
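The test above only checks that from_graph accepts a bare connectivity graph without error. A hypothetical sketch of what such an edge-only import gives you (not part of this diff; the exact defaults filled in for missing node attributes are an assumption):

```python
# Hypothetical sketch: build a small branched graph from edges alone and import it.
# Missing per-node attributes (radius, length, ...) are presumably filled with defaults.
import networkx as nx

from jaxley.io.graph import from_graph

edge_graph = nx.DiGraph([(0, 1), (1, 2), (1, 3)])  # a root with two child branches
edge_module = from_graph(edge_graph)
print(type(edge_module).__name__)  # the resulting jaxley module type
print(edge_module.nodes.shape)     # per-compartment table, as used in the tests above
```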


@pytest.mark.slow
@pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"])
def test_swc2graph_voltages(file):
    """Check that voltages simulated from the imported SWC morphology match NEURON.

    To match the branch indices between NEURON and jaxley, we rely on comparing the
    length of the branches.

    It tests whether, on average over time and recordings, the voltage is off by less
    than 2.5 mV.
    """
    dirname = os.path.dirname(__file__)
    fname = os.path.join(dirname, "swc_files", file)  # n120

    ncomp = 8

    i_delay = 2.0
    i_dur = 5.0
    i_amp = 0.25
    t_max = 20.0
    dt = 0.025

    ##################### NEURON ##################
    h, neuron_cell = import_neuron_morph(fname, nseg=ncomp)

    ####################### jaxley ##################
    graph = swc_to_graph(fname)
    jx_cell = from_graph(
        graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False
    )
    jx_cell.compute_compartment_centers()
    jx_cell.insert(HH())

    branch_loc = 0.05
    neuron_inds, jaxley_inds = jaxley2neuron_by_group(
        jx_cell, h.allsec(), loc=branch_loc
    )
    trunk_inds, tuft_inds, basal_inds = [
        jaxley_inds[key] for key in ["trunk", "tuft", "basal"]
    ]
    neuron_trunk_inds, neuron_tuft_inds, neuron_basal_inds = [
        neuron_inds[key] for key in ["trunk", "tuft", "basal"]
    ]

    stim_loc = 0.1
    stim_idx = match_stim_loc(jx_cell, h.soma[0], loc=stim_loc)

    jx_cell.set("axial_resistivity", 1_000.0)
    jx_cell.set("v", -62.0)
    jx_cell.set("HH_m", 0.074901)
    jx_cell.set("HH_h", 0.4889)
    jx_cell.set("HH_n", 0.3644787)

    jx_cell.select(stim_idx).stimulate(
        jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
    )
    for i in trunk_inds + tuft_inds + basal_inds:
        jx_cell.branch(i).loc(branch_loc).record()

    voltages_jaxley = jx.integrate(jx_cell, delta_t=dt)

    ################### NEURON #################
    stim = h.IClamp(h.soma[0](stim_loc))
    stim.delay = i_delay
    stim.dur = i_dur
    stim.amp = i_amp

    counter = 0
    voltage_recs = {}

    for r in neuron_trunk_inds:
        for i, sec in enumerate(h.allsec()):
            if i == r:
                v = h.Vector()
                v.record(sec(branch_loc)._ref_v)
                voltage_recs[f"v{counter}"] = v
                counter += 1

    for r in neuron_tuft_inds:
        for i, sec in enumerate(h.allsec()):
            if i == r:
                v = h.Vector()
                v.record(sec(branch_loc)._ref_v)
                voltage_recs[f"v{counter}"] = v
                counter += 1

    for r in neuron_basal_inds:
        for i, sec in enumerate(h.allsec()):
            if i == r:
                v = h.Vector()
                v.record(sec(branch_loc)._ref_v)
                voltage_recs[f"v{counter}"] = v
                counter += 1

    for sec in h.allsec():
        sec.insert("hh")
        sec.Ra = 1_000.0

        sec.gnabar_hh = 0.120  # S/cm2
        sec.gkbar_hh = 0.036  # S/cm2
        sec.gl_hh = 0.0003  # S/cm2
        sec.ena = 50  # mV
        sec.ek = -77.0  # mV
        sec.el_hh = -54.3  # mV

    h.dt = dt
    tstop = t_max
    v_init = -62.0

    def initialize():
        h.finitialize(v_init)
        h.fcurrent()

    def integrate():
        while h.t < tstop:
            h.fadvance()

    initialize()
    integrate()
    voltages_neuron = np.asarray([voltage_recs[key] for key in voltage_recs])

    ####################### check ################
    errors = np.mean(np.abs(voltages_jaxley - voltages_neuron), axis=1)

    assert all(errors < 2.5), "voltages do not match."