Skip to content

Commit

Permalink
Speed up tests (swc, plotting) (#479)
Browse files Browse the repository at this point in the history
* fix: speed up swc tests (change to sparse solver)

* create fixtures for plotting tests, split volume plotting in plural tests

* expose resolution; speed up plotting tests #468, #449
  • Loading branch information
fabioseel authored Nov 14, 2024
1 parent ad6ce01 commit 74f5994
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 69 deletions.
48 changes: 38 additions & 10 deletions jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def create_cone_frustum_mesh(
radius_top: float,
bottom_dome: bool = False,
top_dome: bool = False,
resolution: int = 100,
) -> ndarray:
"""Generates mesh points for a cone frustum, with optional domes at either end.
Expand All @@ -120,12 +121,14 @@ def create_cone_frustum_mesh(
The dome is a hemisphere with radius `radius_bottom`.
top_dome: If True, a dome is added to the top of the frustum.
The dome is a hemisphere with radius `radius_top`.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.
Returns:
An array of mesh points.
"""

resolution = 100
t = np.linspace(0, 2 * np.pi, resolution)

# Determine the total height including domes
Expand Down Expand Up @@ -175,7 +178,9 @@ def create_cone_frustum_mesh(
return np.stack([x_coords, y_coords, z_coords])


def create_cylinder_mesh(length: float, radius: float) -> ndarray:
def create_cylinder_mesh(
length: float, radius: float, resolution: int = 100
) -> ndarray:
"""Generates mesh points for a cylinder.
This is used to render cylindrical compartments in 3D (and to project it to 2D)
Expand All @@ -184,12 +189,14 @@ def create_cylinder_mesh(length: float, radius: float) -> ndarray:
Args:
length: The length of the cylinder.
radius: The radius of the cylinder.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.
Returns:
An array of mesh points.
"""
# Define cylinder
resolution = 100
t = np.linspace(0, 2 * np.pi, resolution)
z_coords = np.linspace(-length / 2, length / 2, resolution)
t_grid, z_coords = np.meshgrid(t, z_coords)
Expand All @@ -199,19 +206,21 @@ def create_cylinder_mesh(length: float, radius: float) -> ndarray:
return np.stack([x_coords, y_coords, z_coords])


def create_sphere_mesh(radius: float) -> np.ndarray:
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:
"""Generates mesh points for a sphere.
This is used to render spherical compartments in 3D (and to project it to 2D)
as part of `plot_comps`.
Args:
radius: The radius of the sphere.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.
Returns:
An array of mesh points.
"""
resolution = 100
phi = np.linspace(0, np.pi, resolution)
theta = np.linspace(0, 2 * np.pi, resolution)

Expand Down Expand Up @@ -302,8 +311,9 @@ def plot_comps(
ax: Optional[Axes] = None,
comp_plot_kwargs: Dict = {},
true_comp_length: bool = True,
resolution: int = 100,
) -> Axes:
"""Plot compartmentalized neural mrophology.
"""Plot compartmentalized neural morphology.
Plots the projection of the cylindrical compartments.
Expand All @@ -320,6 +330,9 @@ def plot_comps(
start and end point of the neurite. This can lead to overlapping and
miss-aligned cylinders. Setting this False will use the straight-line
distance instead for nicer plots.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.
Returns:
Plot of the compartmentalized morphology.
Expand All @@ -340,7 +353,7 @@ def plot_comps(
radius = xyzr[:, -1]
center = xyzr[0, :3]
if len(dims) == 3:
xyz = create_sphere_mesh(radius)
xyz = create_sphere_mesh(radius, resolution)
ax = plot_mesh(
xyz,
np.array([0, 0, 1]),
Expand Down Expand Up @@ -368,7 +381,7 @@ def plot_comps(
center = comp[["x", "y", "z"]]
radius = comp["radius"]
length = comp["length"] if true_comp_length else l
xyz = create_cylinder_mesh(length, radius)
xyz = create_cylinder_mesh(length, radius, resolution)
ax = plot_mesh(
xyz,
axis,
Expand All @@ -386,6 +399,7 @@ def plot_morph(
dims: Tuple[int] = (0, 1),
col: str = "k",
ax: Optional[Axes] = None,
resolution: int = 100,
morph_plot_kwargs: Dict = {},
) -> Axes:
"""Plot the detailed morphology.
Expand All @@ -404,6 +418,10 @@ def plot_morph(
ax: The matplotlib axis to plot on.
morph_plot_kwargs: The plot kwargs for plt.fill.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.
Returns:
Plot of the detailed morphology."""
if ax is None:
Expand All @@ -424,7 +442,12 @@ def plot_morph(
dxyz = xyzr2[:3] - xyzr1[:3]
length = np.sqrt(np.sum(dxyz**2))
points = create_cone_frustum_mesh(
length, xyzr1[-1], xyzr2[-1], bottom_dome=True, top_dome=True
length,
xyzr1[-1],
xyzr2[-1],
bottom_dome=True,
top_dome=True,
resolution=resolution,
)
plot_mesh(
points,
Expand All @@ -437,7 +460,12 @@ def plot_morph(
)
else:
points = create_cone_frustum_mesh(
0, xyzr[:, -1], xyzr[:, -1], bottom_dome=True, top_dome=True
0,
xyzr[:, -1],
xyzr[:, -1],
bottom_dome=True,
top_dome=True,
resolution=resolution,
)
plot_mesh(
points,
Expand Down
119 changes: 61 additions & 58 deletions tests/test_plotting_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,58 @@
from jaxley.synapses import IonotropicSynapse


def test_cell():
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell = jx.read_swc(fname, nseg=4)
@pytest.fixture(scope="module")
def comp() -> jx.Compartment:
comp = jx.Compartment()
comp.compute_xyz()
return comp


@pytest.fixture(scope="module")
def branch(comp) -> jx.Branch:
branch = jx.Branch(comp, 4)
branch.compute_xyz()
return branch


@pytest.fixture(scope="module")
def cell(branch) -> jx.Cell:
cell = jx.Cell(branch, [-1, 0, 0, 1, 1])
cell.compute_xyz()
return cell


@pytest.fixture(scope="module")
def simple_net(cell) -> jx.Network:
net = jx.Network([cell] * 4)
net.compute_xyz()
return net


@pytest.fixture(scope="module")
def morph_cell() -> jx.Cell:
morph_cell = jx.read_swc(
os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"),
nseg=1,
)
return morph_cell


def test_cell(morph_cell):
# Plot 1.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = cell.vis(ax=ax)
ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r")
ax = cell.branch(1).loc(0.9).vis(ax=ax, col="b")
ax = morph_cell.vis(ax=ax)
ax = morph_cell.branch([0, 1, 2]).vis(ax=ax, col="r")
ax = morph_cell.branch(1).loc(0.9).vis(ax=ax, col="b")

# Plot 2.
cell.branch(0).add_to_group("soma")
cell.branch(1).add_to_group("soma")
ax = cell.soma.vis()

morph_cell.branch(0).add_to_group("soma")
morph_cell.branch(1).add_to_group("soma")
ax = morph_cell.soma.vis()

def test_network():
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell1 = jx.read_swc(fname, nseg=4)
cell2 = jx.read_swc(fname, nseg=4)
cell3 = jx.read_swc(fname, nseg=4)

net = jx.Network([cell1, cell2, cell3])
def test_network(morph_cell):
net = jx.Network([morph_cell, morph_cell, morph_cell])
connect(
net.cell(0).branch(0).loc(0.0),
net.cell(1).branch(0).loc(0.0),
Expand Down Expand Up @@ -81,11 +108,7 @@ def test_network():
ax = net.excitatory.vis()


def test_vis_networks_built_from_scartch():
comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])

def test_vis_networks_built_from_scratch(comp, branch, cell):
net = jx.Network([cell, cell])
connect(
net.cell(0).branch(0).loc(0.0),
Expand All @@ -110,25 +133,15 @@ def test_vis_networks_built_from_scartch():

# Plot 3.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
comp.compute_xyz()
ax = comp.vis(ax=ax)

# Plot 4.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
branch.compute_xyz()
ax = branch.vis(ax=ax)


def test_mixed_network():
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell1 = jx.read_swc(fname, nseg=4)

comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell2 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])

net = jx.Network([cell1, cell2])
def test_mixed_network(morph_cell, cell):
net = jx.Network([morph_cell, cell])
connect(
net.cell(0).branch(0).loc(0.0),
net.cell(1).branch(0).loc(0.0),
Expand All @@ -145,9 +158,9 @@ def test_mixed_network():
net.cell(1).move(0, -800)
net.rotate(180)

before_xyzrs = deepcopy(net.xyzr[len(cell1.xyzr) :])
before_xyzrs = deepcopy(net.xyzr[len(morph_cell.xyzr) :])
net.cell(1).rotate(90)
after_xyzrs = net.xyzr[len(cell1.xyzr) :]
after_xyzrs = net.xyzr[len(morph_cell.xyzr) :]
# Test that rotation worked as expected.
for b, a in zip(before_xyzrs, after_xyzrs):
assert np.allclose(b[:, 0], -a[:, 1], atol=1e-6)
Expand All @@ -156,34 +169,24 @@ def test_mixed_network():
_ = net.vis(detail="full")


def test_volume_plotting():
comp = jx.Compartment()
comp.compute_xyz()
branch = jx.Branch(comp, 4)
branch.compute_xyz()
cell = jx.Cell([branch] * 3, [-1, 0, 0])
cell.compute_xyz()
net = jx.Network([cell] * 4)
net.compute_xyz()

morph_cell = jx.read_swc(
os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"),
nseg=1,
)

def test_volume_plotting_2d(comp, branch, cell, simple_net, morph_cell):
fig, ax = plt.subplots()
for module in [comp, branch, cell, net, morph_cell]:
module.vis(type="comp", ax=ax)
for module in [comp, branch, cell, simple_net, morph_cell]:
module.vis(type="comp", ax=ax, morph_plot_kwargs={"resolution": 6})
plt.close(fig)


def test_volume_plotting_3d(comp, branch, cell, simple_net, morph_cell):
# test 3D plotting
for module in [comp, branch, cell, net, morph_cell]:
module.vis(type="comp", dims=[0, 1, 2])
for module in [comp, branch, cell, simple_net, morph_cell]:
module.vis(type="comp", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6})
plt.close()


def test_morph_plotting(morph_cell):
# test morph plotting (does not work if no radii in xyzr)
morph_cell.vis(type="morph")
morph_cell.vis(type="morph", morph_plot_kwargs={"resolution": 6})
morph_cell.branch(1).vis(
type="morph", dims=[0, 1, 2]
type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6}
) # plotting whole thing takes too long
plt.close()
2 changes: 1 addition & 1 deletion tests/test_swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_swc_voltages(file):
for i in trunk_inds + tuft_inds + basal_inds:
cell.branch(i).loc(0.05).record()

voltages_jaxley = jx.integrate(cell, delta_t=dt)
voltages_jaxley = jx.integrate(cell, delta_t=dt, voltage_solver="jax.sparse")

################### NEURON #################
stim = h.IClamp(h.soma[0](0.1))
Expand Down

0 comments on commit 74f5994

Please sign in to comment.