Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up tests (swc, plotting) #479

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def create_cone_frustum_mesh(
radius_top: float,
bottom_dome: bool = False,
top_dome: bool = False,
resolution: int = 100,
) -> ndarray:
"""Generates mesh points for a cone frustum, with optional domes at either end.

Expand All @@ -120,12 +121,14 @@ def create_cone_frustum_mesh(
The dome is a hemisphere with radius `radius_bottom`.
top_dome: If True, a dome is added to the top of the frustum.
The dome is a hemisphere with radius `radius_top`.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.

Returns:
An array of mesh points.
"""

resolution = 100
t = np.linspace(0, 2 * np.pi, resolution)

# Determine the total height including domes
Expand Down Expand Up @@ -175,7 +178,9 @@ def create_cone_frustum_mesh(
return np.stack([x_coords, y_coords, z_coords])


def create_cylinder_mesh(length: float, radius: float) -> ndarray:
def create_cylinder_mesh(
length: float, radius: float, resolution: int = 100
) -> ndarray:
"""Generates mesh points for a cylinder.

This is used to render cylindrical compartments in 3D (and to project it to 2D)
Expand All @@ -184,12 +189,14 @@ def create_cylinder_mesh(length: float, radius: float) -> ndarray:
Args:
length: The length of the cylinder.
radius: The radius of the cylinder.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.

Returns:
An array of mesh points.
"""
# Define cylinder
resolution = 100
t = np.linspace(0, 2 * np.pi, resolution)
z_coords = np.linspace(-length / 2, length / 2, resolution)
t_grid, z_coords = np.meshgrid(t, z_coords)
Expand All @@ -199,19 +206,21 @@ def create_cylinder_mesh(length: float, radius: float) -> ndarray:
return np.stack([x_coords, y_coords, z_coords])


def create_sphere_mesh(radius: float) -> np.ndarray:
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:
"""Generates mesh points for a sphere.

This is used to render spherical compartments in 3D (and to project it to 2D)
as part of `plot_comps`.

Args:
radius: The radius of the sphere.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.

Returns:
An array of mesh points.
"""
resolution = 100
phi = np.linspace(0, np.pi, resolution)
theta = np.linspace(0, 2 * np.pi, resolution)

Expand Down Expand Up @@ -302,8 +311,9 @@ def plot_comps(
ax: Optional[Axes] = None,
comp_plot_kwargs: Dict = {},
true_comp_length: bool = True,
resolution: int = 100,
) -> Axes:
"""Plot compartmentalized neural mrophology.
"""Plot compartmentalized neural morphology.

Plots the projection of the cylindrical compartments.

Expand All @@ -320,6 +330,9 @@ def plot_comps(
start and end point of the neurite. This can lead to overlapping and
miss-aligned cylinders. Setting this False will use the straight-line
distance instead for nicer plots.
resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.

Returns:
Plot of the compartmentalized morphology.
Expand All @@ -340,7 +353,7 @@ def plot_comps(
radius = xyzr[:, -1]
center = xyzr[0, :3]
if len(dims) == 3:
xyz = create_sphere_mesh(radius)
xyz = create_sphere_mesh(radius, resolution)
ax = plot_mesh(
xyz,
np.array([0, 0, 1]),
Expand Down Expand Up @@ -368,7 +381,7 @@ def plot_comps(
center = comp[["x", "y", "z"]]
radius = comp["radius"]
length = comp["length"] if true_comp_length else l
xyz = create_cylinder_mesh(length, radius)
xyz = create_cylinder_mesh(length, radius, resolution)
ax = plot_mesh(
xyz,
axis,
Expand All @@ -386,6 +399,7 @@ def plot_morph(
dims: Tuple[int] = (0, 1),
col: str = "k",
ax: Optional[Axes] = None,
resolution: int = 100,
morph_plot_kwargs: Dict = {},
) -> Axes:
"""Plot the detailed morphology.
Expand All @@ -404,6 +418,10 @@ def plot_morph(
ax: The matplotlib axis to plot on.
morph_plot_kwargs: The plot kwargs for plt.fill.

resolution: defines the resolution of the mesh.
If too low (typically <10), can result in errors.
Useful too have a simpler mesh for plotting.

Returns:
Plot of the detailed morphology."""
if ax is None:
Expand All @@ -424,7 +442,12 @@ def plot_morph(
dxyz = xyzr2[:3] - xyzr1[:3]
length = np.sqrt(np.sum(dxyz**2))
points = create_cone_frustum_mesh(
length, xyzr1[-1], xyzr2[-1], bottom_dome=True, top_dome=True
length,
xyzr1[-1],
xyzr2[-1],
bottom_dome=True,
top_dome=True,
resolution=resolution,
)
plot_mesh(
points,
Expand All @@ -437,7 +460,12 @@ def plot_morph(
)
else:
points = create_cone_frustum_mesh(
0, xyzr[:, -1], xyzr[:, -1], bottom_dome=True, top_dome=True
0,
xyzr[:, -1],
xyzr[:, -1],
bottom_dome=True,
top_dome=True,
resolution=resolution,
)
plot_mesh(
points,
Expand Down
119 changes: 61 additions & 58 deletions tests/test_plotting_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,58 @@
from jaxley.synapses import IonotropicSynapse


def test_cell():
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell = jx.read_swc(fname, nseg=4)
@pytest.fixture(scope="module")
Copy link
Contributor

@jnsbck jnsbck Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome! I think we could even move them to a separate file, to make them available to all other tests. tests/conftest.py.

Copy link
Contributor

@jnsbck jnsbck Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to replace as many of the other tests that rely on comp or net etc. with these fixtures. But also fine if you don't and then just leave #449 open and just add a quick comment to it.

EDIT: I have implemented something like this in #499 :)

def comp() -> jx.Compartment:
comp = jx.Compartment()
comp.compute_xyz()
return comp


@pytest.fixture(scope="module")
def branch(comp) -> jx.Branch:
branch = jx.Branch(comp, 4)
branch.compute_xyz()
return branch


@pytest.fixture(scope="module")
def cell(branch) -> jx.Cell:
cell = jx.Cell(branch, [-1, 0, 0, 1, 1])
cell.compute_xyz()
return cell


@pytest.fixture(scope="module")
def simple_net(cell) -> jx.Network:
net = jx.Network([cell] * 4)
net.compute_xyz()
return net


@pytest.fixture(scope="module")
def morph_cell() -> jx.Cell:
morph_cell = jx.read_swc(
os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"),
nseg=1,
)
return morph_cell


def test_cell(morph_cell):
# Plot 1.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = cell.vis(ax=ax)
ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r")
ax = cell.branch(1).loc(0.9).vis(ax=ax, col="b")
ax = morph_cell.vis(ax=ax)
ax = morph_cell.branch([0, 1, 2]).vis(ax=ax, col="r")
ax = morph_cell.branch(1).loc(0.9).vis(ax=ax, col="b")

# Plot 2.
cell.branch(0).add_to_group("soma")
cell.branch(1).add_to_group("soma")
ax = cell.soma.vis()

morph_cell.branch(0).add_to_group("soma")
morph_cell.branch(1).add_to_group("soma")
ax = morph_cell.soma.vis()

def test_network():
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell1 = jx.read_swc(fname, nseg=4)
cell2 = jx.read_swc(fname, nseg=4)
cell3 = jx.read_swc(fname, nseg=4)

net = jx.Network([cell1, cell2, cell3])
def test_network(morph_cell):
net = jx.Network([morph_cell, morph_cell, morph_cell])
connect(
net.cell(0).branch(0).loc(0.0),
net.cell(1).branch(0).loc(0.0),
Expand Down Expand Up @@ -81,11 +108,7 @@ def test_network():
ax = net.excitatory.vis()


def test_vis_networks_built_from_scartch():
comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])

def test_vis_networks_built_from_scratch(comp, branch, cell):
net = jx.Network([cell, cell])
connect(
net.cell(0).branch(0).loc(0.0),
Expand All @@ -110,25 +133,15 @@ def test_vis_networks_built_from_scartch():

# Plot 3.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
comp.compute_xyz()
ax = comp.vis(ax=ax)

# Plot 4.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
branch.compute_xyz()
ax = branch.vis(ax=ax)


def test_mixed_network():
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", "morph.swc")
cell1 = jx.read_swc(fname, nseg=4)

comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell2 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])

net = jx.Network([cell1, cell2])
def test_mixed_network(morph_cell, cell):
net = jx.Network([morph_cell, cell])
connect(
net.cell(0).branch(0).loc(0.0),
net.cell(1).branch(0).loc(0.0),
Expand All @@ -145,9 +158,9 @@ def test_mixed_network():
net.cell(1).move(0, -800)
net.rotate(180)

before_xyzrs = deepcopy(net.xyzr[len(cell1.xyzr) :])
before_xyzrs = deepcopy(net.xyzr[len(morph_cell.xyzr) :])
net.cell(1).rotate(90)
after_xyzrs = net.xyzr[len(cell1.xyzr) :]
after_xyzrs = net.xyzr[len(morph_cell.xyzr) :]
# Test that rotation worked as expected.
for b, a in zip(before_xyzrs, after_xyzrs):
assert np.allclose(b[:, 0], -a[:, 1], atol=1e-6)
Expand All @@ -156,34 +169,24 @@ def test_mixed_network():
_ = net.vis(detail="full")


def test_volume_plotting():
comp = jx.Compartment()
comp.compute_xyz()
branch = jx.Branch(comp, 4)
branch.compute_xyz()
cell = jx.Cell([branch] * 3, [-1, 0, 0])
cell.compute_xyz()
net = jx.Network([cell] * 4)
net.compute_xyz()

morph_cell = jx.read_swc(
os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"),
nseg=1,
)

def test_volume_plotting_2d(comp, branch, cell, simple_net, morph_cell):
fig, ax = plt.subplots()
for module in [comp, branch, cell, net, morph_cell]:
module.vis(type="comp", ax=ax)
for module in [comp, branch, cell, simple_net, morph_cell]:
module.vis(type="comp", ax=ax, morph_plot_kwargs={"resolution": 6})
plt.close(fig)


def test_volume_plotting_3d(comp, branch, cell, simple_net, morph_cell):
# test 3D plotting
for module in [comp, branch, cell, net, morph_cell]:
module.vis(type="comp", dims=[0, 1, 2])
for module in [comp, branch, cell, simple_net, morph_cell]:
module.vis(type="comp", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6})
plt.close()


def test_morph_plotting(morph_cell):
# test morph plotting (does not work if no radii in xyzr)
morph_cell.vis(type="morph")
morph_cell.vis(type="morph", morph_plot_kwargs={"resolution": 6})
morph_cell.branch(1).vis(
type="morph", dims=[0, 1, 2]
type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6}
) # plotting whole thing takes too long
plt.close()
2 changes: 1 addition & 1 deletion tests/test_swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_swc_voltages(file):
for i in trunk_inds + tuft_inds + basal_inds:
cell.branch(i).loc(0.05).record()

voltages_jaxley = jx.integrate(cell, delta_t=dt)
voltages_jaxley = jx.integrate(cell, delta_t=dt, voltage_solver="jax.sparse")

################### NEURON #################
stim = h.IClamp(h.soma[0](0.1))
Expand Down
Loading