Skip to content
2 changes: 1 addition & 1 deletion pyneuroml/plot/Plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def generate_interactive_plot(

for i in range(len(xvalues)):
fig.add_trace(
go.Scatter(
go.Scattergl(
x=xvalues[i],
y=yvalues[i],
name=labels[i],
Expand Down
206 changes: 163 additions & 43 deletions pyneuroml/plot/PlotMorphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pyneuroml.pynml import read_neuroml2_file
from pyneuroml.utils.cli import build_namespace

import neuroml


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand All @@ -33,7 +35,7 @@
"saveToFile": None,
"interactive3d": False,
"plane2d": "xy",
"minwidth": 0.8,
"minwidth": 0,
}


Expand Down Expand Up @@ -65,7 +67,8 @@ def process_args():

parser.add_argument(
"-plane2d",
action="store_true",
type=str,
metavar="<plane, e.g. xy, yz, zx>",
default=DEFAULTS["plane2d"],
help="Plane to plot on for 2D plot",
)
Expand Down Expand Up @@ -116,11 +119,35 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str):
else:
plot_2D(a.nml_file, a.plane2d, a.minwidth, a.v, a.nogui, a.save_to_file)

##########################################################################################
# Taken from https://stackoverflow.com/questions/19394505/expand-the-line-with-specified-width-in-data-unit
from matplotlib.lines import Line2D

class LineDataUnits(Line2D):
def __init__(self, *args, **kwargs):
_lw_data = kwargs.pop("linewidth", 1)
super().__init__(*args, **kwargs)
self._lw_data = _lw_data

def _get_lw(self):
if self.axes is not None:
ppd = 72./self.axes.figure.dpi
trans = self.axes.transData.transform
return ((trans((1, self._lw_data))-trans((0, 0)))*ppd)[1]
else:
return 1

def _set_lw(self, lw):
self._lw_data = lw

_linewidth = property(_get_lw, _set_lw)

##########################################################################################

def plot_2D(
nml_file: str,
plane2d: str = "xy",
min_width: float = 0.8,
min_width: float = DEFAULTS["minwidth"],
verbose: bool = False,
nogui: bool = False,
save_to_file: typing.Optional[str] = None
Expand All @@ -147,66 +174,159 @@ def plot_2D(
if verbose:
print("Plotting %s" % nml_file)

nml_model = read_neuroml2_file(nml_file)
nml_model = read_neuroml2_file(nml_file,
include_includes=True,
check_validity_pre_include=False,
verbose=False,
optimized=True,)

for cell in nml_model.cells:

title = "2D plot of %s from %s" % (cell.id, nml_file)
from pyneuroml.utils import extract_position_info

cell_id_vs_cell, pop_id_vs_cell, positions, pop_id_vs_color = extract_position_info(nml_model, verbose)

fig, ax = plt.subplots(1, 1) # noqa
plt.get_current_fig_manager().set_window_title(title)
title = "2D plot of %s from %s" % (nml_model.networks[0].id, nml_file)

ax.set_aspect("equal")
ax.set_xlabel("extent (um)")
ax.set_ylabel("extent (um)")
if verbose:
print("positions: %s"%positions)
print("pop_id_vs_cell: %s"%pop_id_vs_cell)
print("cell_id_vs_cell: %s"%cell_id_vs_cell)
print("pop_id_vs_color: %s"%pop_id_vs_color)

fig, ax = plt.subplots(1, 1) # noqa
plt.get_current_fig_manager().set_window_title(title)

ax.set_aspect("equal")

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

if plane2d == "xy":
ax.set_xlabel("x (μm)")
ax.set_ylabel("y (μm)")
elif plane2d == "yx":
ax.set_xlabel("y (μm)")
ax.set_ylabel("x (μm)")
elif plane2d == "xz":
ax.set_xlabel("x (μm)")
ax.set_ylabel("z (μm)")
elif plane2d == "zx":
ax.set_xlabel("z (μm)")
ax.set_ylabel("x (μm)")
elif plane2d == "yz":
ax.set_xlabel("y (μm)")
ax.set_ylabel("z (μm)")
elif plane2d == "zy":
ax.set_xlabel("z (μm)")
ax.set_ylabel("y (μm)")
else:
logger.error(f"Invalid value for plane: {plane2d}")
sys.exit(-1)

max_xaxis = -1*float('inf')
min_xaxis = float('inf')

for pop_id in pop_id_vs_cell:
cell = pop_id_vs_cell[pop_id]
pos_pop = positions[pop_id]

for cell_index in pos_pop:
pos = pos_pop[cell_index]

try:
soma_segs = cell.get_all_segments_in_group('soma_group')
except:
soma_segs = []
try:
dend_segs = cell.get_all_segments_in_group('dendrite_group')
except:
dend_segs = []
try:
axon_segs = cell.get_all_segments_in_group('axon_group')
except:
axon_segs = []

for seg in cell.morphology.segments:
p = cell.get_actual_proximal(seg.id)
d = seg.distal
width = (p.diameter + d.diameter)/2

if width < min_width:
width = min_width

color = 'b'
if pop_id in pop_id_vs_color:
color = pop_id_vs_color[pop_id]
else:
if seg.id in soma_segs: color = 'g'
if seg.id in axon_segs: color = 'r'

spherical = p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter

if verbose:
print(
"\nSeg %s, id: %s%s has proximal: %s, distal: %s (width: %s, min_width: %s), color: %s"
% (seg.name, seg.id, ' (spherical)' if spherical else '', p, d, width, min_width, str(color))
)

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position("left")
ax.xaxis.set_ticks_position("bottom")

if plane2d == "xy":
min_xaxis, max_xaxis = add_line(ax, [pos[0]+p.x, pos[0]+d.x], [pos[1]+p.y, pos[1]+d.y], width, color, min_xaxis, max_xaxis)
elif plane2d == "yx":
min_xaxis, max_xaxis = add_line(ax, [pos[1]+p.y, pos[1]+d.y], [pos[0]+p.x, pos[0]+d.x], width, color, min_xaxis, max_xaxis)
elif plane2d == "xz":
min_xaxis, max_xaxis = add_line(ax, [pos[0]+p.x, pos[0]+d.x], [pos[2]+p.z, pos[2]+d.z], width, color, min_xaxis, max_xaxis)
elif plane2d == "zx":
min_xaxis, max_xaxis = add_line(ax, [pos[2]+p.z, pos[2]+d.z], [pos[0]+p.x, pos[0]+d.x], width, color, min_xaxis, max_xaxis)
elif plane2d == "yz":
min_xaxis, max_xaxis = add_line(ax, [pos[1]+p.y, pos[1]+d.y], [pos[2]+p.z, pos[2]+d.z], width, color, min_xaxis, max_xaxis)
elif plane2d == "zy":
min_xaxis, max_xaxis = add_line(ax, [pos[2]+p.z, pos[2]+d.z], [pos[1]+p.y, pos[1]+d.y], width, color, min_xaxis, max_xaxis)
else:
logger.error(f"Invalid value for plane: {plane2d}")
sys.exit(-1)

if verbose: print('Extent x: %s -> %s'%(min_xaxis, max_xaxis))
# add a scalebar
# ax = fig.add_axes([0, 0, 1, 1])
sc_val = 50
if max_xaxis-min_xaxis<100:
sc_val = 5
if max_xaxis-min_xaxis<10:
sc_val = 1
scalebar1 = ScaleBar(0.001, units="mm", dimension="si-length",
scale_loc="top", location="lower right",
fixed_value=50, fixed_units="um")
fixed_value=sc_val, fixed_units="um", box_alpha=.8)
ax.add_artist(scalebar1)

for seg in cell.morphology.segments:
p = cell.get_actual_proximal(seg.id)
d = seg.distal
if verbose:
print(
"\nSegment %s, id: %s has proximal point: %s, distal: %s"
% (seg.name, seg.id, p, d)
)
width = max(p.diameter, d.diameter)
if width < min_width:
width = min_width
if plane2d == "xy":
plt.plot([p.x, d.x], [p.y, d.y], linewidth=width, color="b")
elif plane2d == "yx":
plt.plot([p.y, d.y], [p.x, d.x], linewidth=width, color="b")
elif plane2d == "xz":
plt.plot([p.x, d.x], [p.z, d.z], linewidth=width, color="b")
elif plane2d == "zx":
plt.plot([p.z, d.z], [p.x, d.x], linewidth=width, color="b")
elif plane2d == "yz":
plt.plot([p.y, d.y], [p.z, d.z], linewidth=width, color="b")
elif plane2d == "zy":
plt.plot([p.z, d.z], [p.y, d.y], linewidth=width, color="b")
else:
logger.error(f"Invalid value for plane: {plane2d}")
sys.exit(-1)
plt.autoscale()


if save_to_file:
abs_file = os.path.abspath(save_to_file)
plt.savefig(abs_file, dpi=200)
plt.savefig(abs_file, dpi=200, bbox_inches="tight")
print(f"Saved image on plane {plane2d} to {abs_file} of plot: {title}")

if not nogui:
plt.show()

def add_line(ax, xv, yv, width, color, min_xaxis, max_xaxis):

if abs(xv[0]-xv[1])<0.01 and abs(yv[0]-yv[1])<0.01: # looking at the cylinder from the top, OR a sphere, so draw a circle
xv[1]=xv[1]+width/1000.
yv[1]=yv[1]+width/1000.

ax.add_line(LineDataUnits(xv, yv, linewidth=width, solid_capstyle='round',color=color))

ax.add_line(LineDataUnits(xv, yv, linewidth=width, solid_capstyle='butt', color=color))

min_xaxis=min(min_xaxis,xv[0])
min_xaxis=min(min_xaxis,xv[1])
max_xaxis=max(max_xaxis,xv[0])
max_xaxis=max(max_xaxis,xv[1])
return min_xaxis, max_xaxis

def plot_interactive_3D(
nml_file: str,
Expand Down
87 changes: 87 additions & 0 deletions pyneuroml/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import neuroml

def extract_position_info(nml_model, verbose):

cell_id_vs_cell = {}
positions = {}
pop_id_vs_cell = {}
pop_id_vs_color = {}

cell_elements = []
cell_elements.extend(nml_model.cells)
cell_elements.extend(nml_model.cell2_ca_poolses)

for cell in cell_elements:
cell_id_vs_cell[cell.id] = cell

if len(nml_model.networks)>0:
popElements = nml_model.networks[0].populations
else:
popElements = []
net = neuroml.Network(id='x')
nml_model.networks.append(net)
cell_str = ''
for cell in cell_elements:
pop = neuroml.Population(id='dummy_population_%s'%cell.id, size=1, component=cell.id)
net.populations.append(pop)
cell_str+=cell.id+'__'
net.id=cell_str[:-2]

popElements = nml_model.networks[0].populations


for pop in popElements:
name = pop.id
celltype = pop.component
instances = pop.instances

if pop.component in cell_id_vs_cell.keys():
pop_id_vs_cell[pop.id] = cell_id_vs_cell[pop.component]

info = "Population: %s has %i positioned cells of type: %s" % (
name,
len(instances),
celltype,
)
if verbose: print(info)

colour = "b"
substitute_radius = None

props = []
props.extend(pop.properties)
''' TODO
if pop.annotation:
props.extend(pop.annotation.properties)'''

for prop in props:
print(prop)
if prop.tag == "color":
color = prop.value
color = (float(color.split(' ')[0]),
float(color.split(' ')[1]),
float(color.split(' ')[2]))

pop_id_vs_color[pop.id]=color
print("Colour determined to be: %s"%str(color))
if prop.tag == "radius":
substitute_radius = float(prop.value)

pop_positions = {}

if len(instances)>0:
for instance in instances:
location = instance.location
id = int(instance.id)

x = float(location.x)
y = float(location.y)
z = float(location.z)
pop_positions[id] = (x, y, z)
else:
for id in range(pop.size):
pop_positions[id] = (0,0,0)

positions[name] = pop_positions

return cell_id_vs_cell, pop_id_vs_cell, positions, pop_id_vs_color
12 changes: 6 additions & 6 deletions tests/plot/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,21 @@ def test_generate_interactive_plot(self):
except FileNotFoundError:
pass

xs = [*range(5, 15)]
ys = [*range(5, 15)]
xs1 = [*range(5, 15)]
ys1 = [*range(14, 4, -1)]
xs = [*range(5, 15000)]
ys = [*range(5, 15000)]
xs1 = [*range(5, 15000)]
ys1 = [*range(14999, 4, -1)]
labels = ["up", "down"]
generate_interactive_plot(
xvalues=[xs, xs1],
yvalues=[ys, ys1],
labels=labels,
modes=["lines+markers", "markers"],
title="test interactive plot",
title=f"test interactive plot with {len(xs) * 2} points",
plot_bgcolor="white",
xaxis="x axis",
yaxis="y axis",
show_interactive=False,
show_interactive=True,
xaxis_spikelines=True,
yaxis_spikelines=False,
save_figure_to=filename,
Expand Down