Skip to content
177 changes: 99 additions & 78 deletions pyneuroml/plot/PlotMorphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pyneuroml.pynml import read_neuroml2_file
from pyneuroml.utils.cli import build_namespace

import neuroml


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -172,97 +174,116 @@ def plot_2D(
if verbose:
print("Plotting %s" % nml_file)

nml_model = read_neuroml2_file(nml_file)
nml_model = read_neuroml2_file(nml_file,
include_includes=True,
check_validity_pre_include=False,
verbose=False,
optimized=True,)

for cell in nml_model.cells:

title = "2D plot of %s from %s" % (cell.id, nml_file)

fig, ax = plt.subplots(1, 1) # noqa
plt.get_current_fig_manager().set_window_title(title)

ax.set_aspect("equal")

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")


if plane2d == "xy":
ax.set_xlabel("x (μm)")
ax.set_ylabel("y (μm)")
elif plane2d == "yx":
ax.set_xlabel("y (μm)")
ax.set_ylabel("x (μm)")
elif plane2d == "xz":
ax.set_xlabel("x (μm)")
ax.set_ylabel("z (μm)")
elif plane2d == "zx":
ax.set_xlabel("z (μm)")
ax.set_ylabel("x (μm)")
elif plane2d == "yz":
ax.set_xlabel("y (μm)")
ax.set_ylabel("z (μm)")
elif plane2d == "zy":
ax.set_xlabel("z (μm)")
ax.set_ylabel("y (μm)")
else:
logger.error(f"Invalid value for plane: {plane2d}")
sys.exit(-1)

max_xaxis = -1*float('inf')
min_xaxis = float('inf')

try:
soma_segs = cell.get_all_segments_in_group('soma_group')
except:
soma_segs = []
try:
dend_segs = cell.get_all_segments_in_group('dendrite_group')
except:
dend_segs = []
try:
axon_segs = cell.get_all_segments_in_group('axon_group')
except:
axon_segs = []
from pyneuroml.utils import extract_position_info

for seg in cell.morphology.segments:
p = cell.get_actual_proximal(seg.id)
d = seg.distal
width = (p.diameter + d.diameter)/2
cell_id_vs_cell, pop_id_vs_cell, positions, pop_id_vs_color = extract_position_info(nml_model, verbose)

if width < min_width:
width = min_width
title = "2D plot of %s from %s" % (nml_model.networks[0].id, nml_file)

color = 'b'
if seg.id in soma_segs: color = 'g'
if seg.id in axon_segs: color = 'r'
if verbose:
print("positions: %s"%positions)
print("pop_id_vs_cell: %s"%pop_id_vs_cell)
print("cell_id_vs_cell: %s"%cell_id_vs_cell)
print("pop_id_vs_color: %s"%pop_id_vs_color)

fig, ax = plt.subplots(1, 1) # noqa
plt.get_current_fig_manager().set_window_title(title)

ax.set_aspect("equal")

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

if plane2d == "xy":
ax.set_xlabel("x (μm)")
ax.set_ylabel("y (μm)")
elif plane2d == "yx":
ax.set_xlabel("y (μm)")
ax.set_ylabel("x (μm)")
elif plane2d == "xz":
ax.set_xlabel("x (μm)")
ax.set_ylabel("z (μm)")
elif plane2d == "zx":
ax.set_xlabel("z (μm)")
ax.set_ylabel("x (μm)")
elif plane2d == "yz":
ax.set_xlabel("y (μm)")
ax.set_ylabel("z (μm)")
elif plane2d == "zy":
ax.set_xlabel("z (μm)")
ax.set_ylabel("y (μm)")
else:
logger.error(f"Invalid value for plane: {plane2d}")
sys.exit(-1)

max_xaxis = -1*float('inf')
min_xaxis = float('inf')

for pop_id in pop_id_vs_cell:
cell = pop_id_vs_cell[pop_id]
pos_pop = positions[pop_id]

for cell_index in pos_pop:
pos = pos_pop[cell_index]

try:
soma_segs = cell.get_all_segments_in_group('soma_group')
except:
soma_segs = []
try:
dend_segs = cell.get_all_segments_in_group('dendrite_group')
except:
dend_segs = []
try:
axon_segs = cell.get_all_segments_in_group('axon_group')
except:
axon_segs = []

for seg in cell.morphology.segments:
p = cell.get_actual_proximal(seg.id)
d = seg.distal
width = (p.diameter + d.diameter)/2

if width < min_width:
width = min_width

color = 'b'
if pop_id in pop_id_vs_color:
color = pop_id_vs_color[pop_id]
else:
if seg.id in soma_segs: color = 'g'
if seg.id in axon_segs: color = 'r'

spherical = p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter
spherical = p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter

if verbose:
print(
"\nSeg %s, id: %s%s has proximal: %s, distal: %s (width: %s, min_width: %s), color: %s"
% (seg.name, seg.id, ' (spherical)' if spherical else '', p, d, width, min_width, str(color))
)

if verbose:
print(
"\nSeg %s, id: %s%s has proximal: %s, distal: %s (width: %s, min_width: %s), color: %s"
% (seg.name, seg.id, ' (spherical)' if spherical else '', p, d, width, min_width, color)
)

if spherical:
if verbose: print("Segment is spherical")
ax.add_line(LineDataUnits([p.x+width/1000., d.x], [p.y, d.y+width/1000.], linewidth=width, solid_capstyle='round',color='r'))
else:
if plane2d == "xy":
min_xaxis, max_xaxis = add_line(ax, [p.x, d.x], [p.y, d.y], width, color, min_xaxis, max_xaxis)
min_xaxis, max_xaxis = add_line(ax, [pos[0]+p.x, pos[0]+d.x], [pos[1]+p.y, pos[1]+d.y], width, color, min_xaxis, max_xaxis)
elif plane2d == "yx":
min_xaxis, max_xaxis = add_line(ax, [p.y, d.y], [p.x, d.x], width, color, min_xaxis, max_xaxis)
min_xaxis, max_xaxis = add_line(ax, [pos[1]+p.y, pos[1]+d.y], [pos[0]+p.x, pos[0]+d.x], width, color, min_xaxis, max_xaxis)
elif plane2d == "xz":
min_xaxis, max_xaxis = add_line(ax, [p.x, d.x], [p.z, d.z], width, color, min_xaxis, max_xaxis)
min_xaxis, max_xaxis = add_line(ax, [pos[0]+p.x, pos[0]+d.x], [pos[2]+p.z, pos[2]+d.z], width, color, min_xaxis, max_xaxis)
elif plane2d == "zx":
min_xaxis, max_xaxis = add_line(ax, [p.z, d.z], [p.x, d.x], width, color, min_xaxis, max_xaxis)
min_xaxis, max_xaxis = add_line(ax, [pos[2]+p.z, pos[2]+d.z], [pos[0]+p.x, pos[0]+d.x], width, color, min_xaxis, max_xaxis)
elif plane2d == "yz":
min_xaxis, max_xaxis = add_line(ax, [p.y, d.y], [p.z, d.z], width, color, min_xaxis, max_xaxis)
min_xaxis, max_xaxis = add_line(ax, [pos[1]+p.y, pos[1]+d.y], [pos[2]+p.z, pos[2]+d.z], width, color, min_xaxis, max_xaxis)
elif plane2d == "zy":
min_xaxis, max_xaxis = add_line(ax, [p.z, d.z], [p.y, d.y], width, color, min_xaxis, max_xaxis)
min_xaxis, max_xaxis = add_line(ax, [pos[2]+p.z, pos[2]+d.z], [pos[1]+p.y, pos[1]+d.y], width, color, min_xaxis, max_xaxis)
else:
logger.error(f"Invalid value for plane: {plane2d}")
sys.exit(-1)
Expand Down Expand Up @@ -293,7 +314,7 @@ def plot_2D(

def add_line(ax, xv, yv, width, color, min_xaxis, max_xaxis):

if abs(xv[0]-xv[1])<0.01 and abs(yv[0]-yv[1])<0.01: # looking at the cylinder from the top, so draw a circle
if abs(xv[0]-xv[1])<0.01 and abs(yv[0]-yv[1])<0.01: # looking at the cylinder from the top, OR a sphere, so draw a circle
xv[1]=xv[1]+width/1000.
yv[1]=yv[1]+width/1000.

Expand Down
87 changes: 87 additions & 0 deletions pyneuroml/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import neuroml

def extract_position_info(nml_model, verbose):

cell_id_vs_cell = {}
positions = {}
pop_id_vs_cell = {}
pop_id_vs_color = {}

cell_elements = []
cell_elements.extend(nml_model.cells)
cell_elements.extend(nml_model.cell2_ca_poolses)

for cell in cell_elements:
cell_id_vs_cell[cell.id] = cell

if len(nml_model.networks)>0:
popElements = nml_model.networks[0].populations
else:
popElements = []
net = neuroml.Network(id='x')
nml_model.networks.append(net)
cell_str = ''
for cell in cell_elements:
pop = neuroml.Population(id='dummy_population_%s'%cell.id, size=1, component=cell.id)
net.populations.append(pop)
cell_str+=cell.id+'__'
net.id=cell_str[:-2]

popElements = nml_model.networks[0].populations


for pop in popElements:
name = pop.id
celltype = pop.component
instances = pop.instances

if pop.component in cell_id_vs_cell.keys():
pop_id_vs_cell[pop.id] = cell_id_vs_cell[pop.component]

info = "Population: %s has %i positioned cells of type: %s" % (
name,
len(instances),
celltype,
)
if verbose: print(info)

colour = "b"
substitute_radius = None

props = []
props.extend(pop.properties)
''' TODO
if pop.annotation:
props.extend(pop.annotation.properties)'''

for prop in props:
print(prop)
if prop.tag == "color":
color = prop.value
color = (float(color.split(' ')[0]),
float(color.split(' ')[1]),
float(color.split(' ')[2]))

pop_id_vs_color[pop.id]=color
print("Colour determined to be: %s"%str(color))
if prop.tag == "radius":
substitute_radius = float(prop.value)

pop_positions = {}

if len(instances)>0:
for instance in instances:
location = instance.location
id = int(instance.id)

x = float(location.x)
y = float(location.y)
z = float(location.z)
pop_positions[id] = (x, y, z)
else:
for id in range(pop.size):
pop_positions[id] = (0,0,0)

positions[name] = pop_positions

return cell_id_vs_cell, pop_id_vs_cell, positions, pop_id_vs_color