diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f99ce792..8b29368f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,6 @@ jobs: - name: Run tests run: | - pytest --cov=pyneuroml -m "not localonly" . pynml -h ./test-ghactions.sh -neuron diff --git a/pyneuroml/__init__.py b/pyneuroml/__init__.py index 99f16415..b765b4e0 100644 --- a/pyneuroml/__init__.py +++ b/pyneuroml/__init__.py @@ -1,6 +1,6 @@ import logging -__version__ = "0.7.6" +__version__ = "1.0.0" JNEUROML_VERSION = "0.12.1" diff --git a/pyneuroml/analysis/__init__.py b/pyneuroml/analysis/__init__.py index 38768419..22ca9bbb 100644 --- a/pyneuroml/analysis/__init__.py +++ b/pyneuroml/analysis/__init__.py @@ -7,6 +7,7 @@ from pyneuroml import pynml from pyneuroml.lems.LEMSSimulation import LEMSSimulation from pyneuroml.lems import generate_lems_file_for_neuroml +from pyneuroml.utils.plot import get_next_hex_color import neuroml as nml from pyelectro.analysis import max_min from pyelectro.analysis import mean_spike_frequency @@ -240,7 +241,7 @@ def generate_current_vs_frequency_curve( for i in range(number_cells): ref = "v_cell%i" % i quantity = "%s[%i]/v" % (pop.id, i) - ls.add_line_to_display(disp0, ref, quantity, "1mV", pynml.get_next_hex_color()) + ls.add_line_to_display(disp0, ref, quantity, "1mV", get_next_hex_color()) ls.add_column_to_output_file(of0, ref, quantity) lems_file_name = ls.save_to_file() diff --git a/pyneuroml/lems/LEMSSimulation.py b/pyneuroml/lems/LEMSSimulation.py index b1bad303..77d12e43 100644 --- a/pyneuroml/lems/LEMSSimulation.py +++ b/pyneuroml/lems/LEMSSimulation.py @@ -13,7 +13,7 @@ from neuroml import __version__ as libnml_ver from pyneuroml.pynml import read_neuroml2_file from pyneuroml.pynml import read_lems_file -from pyneuroml.pynml import get_next_hex_color +from pyneuroml.utils.plot import get_next_hex_color logger = logging.getLogger(__name__) diff --git a/pyneuroml/lems/__init__.py b/pyneuroml/lems/__init__.py index bc99316e..1ca81496 100644 --- a/pyneuroml/lems/__init__.py +++ b/pyneuroml/lems/__init__.py @@ -4,7 +4,8 @@ import shutil import os import logging -from pyneuroml.pynml import read_neuroml2_file, get_next_hex_color +from pyneuroml.pynml import read_neuroml2_file +from pyneuroml.utils.plot import get_next_hex_color import random import neuroml diff --git a/pyneuroml/plot/Plot.py b/pyneuroml/plot/Plot.py index 15bfd8ae..0e84a88f 100644 --- a/pyneuroml/plot/Plot.py +++ b/pyneuroml/plot/Plot.py @@ -45,7 +45,8 @@ def generate_plot( save_figure_to: typing.Optional[str] = None, title_above_plot: bool = False, verbose: bool = False, -) -> matplotlib.axes.Axes: + close_plot: bool = False, +) -> typing.Optional[matplotlib.axes.Axes]: """Utility function to generate plots using the Matplotlib library. This function can be used to generate graphs with multiple plot lines. @@ -127,7 +128,9 @@ def generate_plot( :type title_above_plot: boolean :param verbose: enable/disable verbose logging (default: False) :type verbose: boolean - :returns: matplotlib.axes.Axes object + :param close_plot: call pyplot.close() to close plot after plotting + :type close_plot: bool + :returns: matplotlib.axes.Axes object if plot is not closed, else None """ logger.info("Generating plot: %s" % (title)) @@ -239,7 +242,13 @@ def generate_plot( if show_plot_already: plt.show() - return ax + if close_plot: + logger.info("Closing plot") + plt.close() + else: + return ax + + return None def generate_interactive_plot( diff --git a/pyneuroml/plot/PlotMorphology.py b/pyneuroml/plot/PlotMorphology.py index 7ec04e6c..7d3a17e9 100644 --- a/pyneuroml/plot/PlotMorphology.py +++ b/pyneuroml/plot/PlotMorphology.py @@ -11,17 +11,34 @@ import argparse import os import sys +import random import typing import logging +from vispy import app, scene +import numpy +import matplotlib from matplotlib import pyplot as plt -from matplotlib_scalebar.scalebar import ScaleBar import plotly.graph_objects as go from pyneuroml.pynml import read_neuroml2_file from pyneuroml.utils.cli import build_namespace from pyneuroml.utils import extract_position_info +from pyneuroml.utils.plot import ( + add_text_to_matplotlib_2D_plot, + add_text_to_vispy_3D_plot, + get_next_hex_color, + add_box_to_matplotlib_2D_plot, + get_new_matplotlib_morph_plot, + autoscale_matplotlib_plot, + add_scalebar_to_matplotlib_plot, + add_line_to_matplotlib_2D_plot, + create_new_vispy_canvas, + get_cell_bound_box, +) +from neuroml import SegmentGroup, Cell, Segment +from neuroml.neuro_lex_ids import neuro_lex_ids logger = logging.getLogger(__name__) @@ -34,9 +51,10 @@ "saveToFile": None, "interactive3d": False, "plane2d": "xy", - "minwidth": 0.8, + "minWidth": 0.8, "square": False, - "plotType": "Detailed" + "plotType": "Constant", + "theme": "light", } @@ -79,12 +97,20 @@ def process_args(): type=str, metavar="", default=DEFAULTS["plotType"], - help="Plane to plot on for 2D plot", + help="Level of detail to plot in", + ) + parser.add_argument( + "-theme", + type=str, + metavar="", + default=DEFAULTS["theme"], + help="Theme to use for interactive 3d plotting", ) parser.add_argument( "-minWidth", - action="store_true", - default=DEFAULTS["minwidth"], + type=float, + metavar="", + default=DEFAULTS["minWidth"], help="Minimum width of lines to use", ) @@ -100,14 +126,14 @@ def process_args(): type=str, metavar="", default=None, - help="Name of the image file", + help="Name of the image file, for 2D plot", ) parser.add_argument( "-square", action="store_true", default=DEFAULTS["square"], - help="Scale axes so that image is approximately square", + help="Scale axes so that image is approximately square, for 2D plot", ) return parser.parse_args() @@ -130,55 +156,235 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str): a = build_namespace(DEFAULTS, a, **kwargs) print(a) if a.interactive3d: - plot_interactive_3D(a.nml_file, a.minwidth, a.v, a.nogui, a.save_to_file) + plot_interactive_3D( + nml_file=a.nml_file, + min_width=a.min_width, + verbose=a.v, + plot_type=a.plot_type, + theme=a.theme, + ) else: plot_2D( - a.nml_file, a.plane2d, a.minwidth, a.v, a.nogui, a.save_to_file, - a.square, a.plot_type + a.nml_file, + a.plane2d, + a.min_width, + a.v, + a.nogui, + a.save_to_file, + a.square, + a.plot_type, + ) + + +def plot_interactive_3D( + nml_file: str, + min_width: float = DEFAULTS["minWidth"], + verbose: bool = False, + plot_type: str = "Constant", + title: typing.Optional[str] = None, + theme: str = "light", + nogui: bool = False, +): + """Plot interactive plots in 3D using Vispy + + https://vispy.org + + :param nml_file: path to NeuroML cell file + :type nml_file: str + :param min_width: minimum width for segments (useful for visualising very + thin segments): default 0.8um + :type min_width: float + :param verbose: show extra information (default: False) + :type verbose: bool + :param plot_type: type of plot, one of: + + - "Detailed": show detailed morphology taking into account each segment's + width + - "Constant": show morphology, but use constant line widths + - "Schematic": only plot each unbranched segment group as a straight + line, not following each segment + + This is only applicable for neuroml.Cell cells (ones with some + morphology) + + :type plot_type: str + :param title: title of plot + :type title: str + :param theme: theme to use (light/dark) + :type theme: str + :param nogui: toggle showing gui (for testing only) + :type nogui: bool + """ + if plot_type not in ["Detailed", "Constant", "Schematic"]: + raise ValueError( + "plot_type must be one of 'Detailed', 'Constant', or 'Schematic'" ) + if verbose: + print(f"Plotting {nml_file}") -########################################################################################## -# Taken from https://stackoverflow.com/questions/19394505/expand-the-line-with-specified-width-in-data-unit -from matplotlib.lines import Line2D + nml_model = read_neuroml2_file( + nml_file, + include_includes=True, + check_validity_pre_include=False, + verbose=False, + optimized=True, + ) + ( + cell_id_vs_cell, + pop_id_vs_cell, + positions, + pop_id_vs_color, + pop_id_vs_radii, + ) = extract_position_info(nml_model, verbose) -class LineDataUnits(Line2D): - def __init__(self, *args, **kwargs): - _lw_data = kwargs.pop("linewidth", 1) - super().__init__(*args, **kwargs) - self._lw_data = _lw_data + # Collect all markers and only plot one markers object + # this is more efficient than multiple markers, one for each point. + # TODO: also collect all line points and only use one object rather than a + # new object for each cell: will only work for the case where all lines + # have the same width + marker_sizes = [] + marker_points = [] + marker_colors = [] + + if title is None: + title = f"{nml_model.networks[0].id} from {nml_file}" + + logger.debug(f"positions: {positions}") + logger.debug(f"pop_id_vs_cell: {pop_id_vs_cell}") + logger.debug(f"cell_id_vs_cell: {cell_id_vs_cell}") + logger.debug(f"pop_id_vs_color: {pop_id_vs_color}") + logger.debug(f"pop_id_vs_radii: {pop_id_vs_radii}") + + if len(positions.keys()) > 1: + only_pos = [] + for posdict in positions.values(): + for poss in posdict.values(): + only_pos.append(poss) + + pos_array = numpy.array(only_pos) + center = numpy.array( + [ + numpy.mean(pos_array[:, 0]), + numpy.mean(pos_array[:, 1]), + numpy.mean(pos_array[:, 2]), + ] + ) + x_min = numpy.min(pos_array[:, 0]) + x_max = numpy.max(pos_array[:, 0]) + x_len = abs(x_max - x_min) + + y_min = numpy.min(pos_array[:, 1]) + y_max = numpy.max(pos_array[:, 1]) + y_len = abs(y_max - y_min) + + z_min = numpy.min(pos_array[:, 2]) + z_max = numpy.max(pos_array[:, 2]) + z_len = abs(z_max - z_min) - def _get_lw(self): - if self.axes is not None: - ppd = 72.0 / self.axes.figure.dpi - trans = self.axes.transData.transform - return ((trans((1, self._lw_data)) - trans((0, 0))) * ppd)[1] + view_min = center - numpy.array([x_len, y_len, z_len]) + view_max = center + numpy.array([x_len, y_len, z_len]) + logger.debug(f"center, view_min, max are {center}, {view_min}, {view_max}") + + else: + cell = list(pop_id_vs_cell.values())[0] + if cell is not None: + view_min, view_max = get_cell_bound_box(cell) else: - return 1 + logger.debug("Got a point cell") + pos = list((list(positions.values())[0]).values())[0] + view_min = list(numpy.array(pos)) + view_min = list(numpy.array(pos)) + + current_scene, current_view = create_new_vispy_canvas( + view_min, view_max, title, theme=theme + ) - def _set_lw(self, lw): - self._lw_data = lw + logger.debug(f"figure extents are: {view_min}, {view_max}") - _linewidth = property(_get_lw, _set_lw) + for pop_id in pop_id_vs_cell: + cell = pop_id_vs_cell[pop_id] + pos_pop = positions[pop_id] + for cell_index in pos_pop: + pos = pos_pop[cell_index] + radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 + color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None -########################################################################################## + try: + logging.info(f"Plotting {cell.id}") + except AttributeError: + logging.info(f"Plotting a point cell at {pos}") + + if cell is None: + marker_points.extend([pos]) + marker_sizes.extend([radius]) + marker_colors.extend([color]) + else: + if plot_type == "Schematic": + plot_3D_schematic( + offset=pos, + cell=cell, + segment_groups=None, + labels=True, + verbose=verbose, + current_scene=current_scene, + current_view=current_view, + nogui=True, + ) + else: + pts, sizes, colors = plot_3D_cell_morphology( + offset=pos, + cell=cell, + color=color, + plot_type=plot_type, + verbose=verbose, + current_scene=current_scene, + current_view=current_view, + min_width=min_width, + nogui=True, + ) + marker_points.extend(pts) + marker_sizes.extend(sizes) + marker_colors.extend(colors) + + if len(marker_points) > 0: + scene.Markers( + pos=numpy.array(marker_points), + size=numpy.array(marker_sizes), + spherical=True, + face_color=marker_colors, + edge_color=marker_colors, + edge_width=0, + parent=current_view.scene, + scaling=True, + antialias=0, + ) + if not nogui: + app.run() def plot_2D( nml_file: str, plane2d: str = "xy", - min_width: float = DEFAULTS["minwidth"], + min_width: float = DEFAULTS["minWidth"], verbose: bool = False, nogui: bool = False, save_to_file: typing.Optional[str] = None, square: bool = False, plot_type: str = "Detailed", + title: typing.Optional[str] = None, + close_plot: bool = False, ): - """Plot cell morphology in 2D. + """Plot cells in a 2D plane. + + If a file with a network containing multiple cells is provided, it will + plot all the cells. For detailed neuroml.Cell types, it will plot their + complete morphology. For point neurons, we only plot the points (locations) + where they are. - This uses matplotlib to plot the morphology in 2D. + This method uses matplotlib. :param nml_file: path to NeuroML cell file :type nml_file: str @@ -196,16 +402,27 @@ def plot_2D( :param square: scale axes so that image is approximately square :type square: bool :param plot_type: type of plot, one of: - - Detailed: show detailed morphology taking into account each segment's + + - "Detailed": show detailed morphology taking into account each segment's width - - Constant: show morphology, but use constant line widths - - Schematic: only plot each unbranched segment group as a straight + - "Constant": show morphology, but use constant line widths + - "Schematic": only plot each unbranched segment group as a straight line, not following each segment + + This is only applicable for neuroml.Cell cells (ones with some + morphology) + :type plot_type: str + :param title: title of plot + :type title: str + :param close_plot: call pyplot.close() to close plot after plotting + :type close_plot: bool """ if plot_type not in ["Detailed", "Constant", "Schematic"]: - raise ValueError("plot_type must be one of 'Detailed', 'Constant', or 'Schematic'") + raise ValueError( + "plot_type must be one of 'Detailed', 'Constant', or 'Schematic'" + ) if verbose: print("Plotting %s" % nml_file) @@ -226,7 +443,8 @@ def plot_2D( pop_id_vs_radii, ) = extract_position_info(nml_model, verbose) - title = "2D plot of %s from %s" % (nml_model.networks[0].id, nml_file) + if title is None: + title = "2D plot of %s from %s" % (nml_model.networks[0].id, nml_file) if verbose: logger.debug(f"positions: {positions}") @@ -235,40 +453,8 @@ def plot_2D( logger.debug(f"pop_id_vs_color: {pop_id_vs_color}") logger.debug(f"pop_id_vs_radii: {pop_id_vs_radii}") - fig, ax = plt.subplots(1, 1) # noqa - plt.get_current_fig_manager().set_window_title(title) - - ax.set_aspect("equal") - - ax.spines["right"].set_visible(False) - ax.spines["top"].set_visible(False) - ax.yaxis.set_ticks_position("left") - ax.xaxis.set_ticks_position("bottom") - - if plane2d == "xy": - ax.set_xlabel("x (μm)") - ax.set_ylabel("y (μm)") - elif plane2d == "yx": - ax.set_xlabel("y (μm)") - ax.set_ylabel("x (μm)") - elif plane2d == "xz": - ax.set_xlabel("x (μm)") - ax.set_ylabel("z (μm)") - elif plane2d == "zx": - ax.set_xlabel("z (μm)") - ax.set_ylabel("x (μm)") - elif plane2d == "yz": - ax.set_xlabel("y (μm)") - ax.set_ylabel("z (μm)") - elif plane2d == "zy": - ax.set_xlabel("z (μm)") - ax.set_ylabel("y (μm)") - else: - logger.error(f"Invalid value for plane: {plane2d}") - sys.exit(-1) - - max_xaxis = -1 * float("inf") - min_xaxis = float("inf") + fig, ax = get_new_matplotlib_morph_plot(title, plane2d) + axis_min_max = [float("inf"), -1 * float("inf")] for pop_id in pop_id_vs_cell: cell = pop_id_vs_cell[pop_id] @@ -276,248 +462,58 @@ def plot_2D( for cell_index in pos_pop: pos = pos_pop[cell_index] - - try: - soma_segs = cell.get_all_segments_in_group("soma_group") - except: - soma_segs = [] - try: - dend_segs = cell.get_all_segments_in_group("dendrite_group") - except: - dend_segs = [] - try: - axon_segs = cell.get_all_segments_in_group("axon_group") - except: - axon_segs = [] + radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 + color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None if cell is None: - - radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 - color = "b" - if pop_id in pop_id_vs_color: - color = pop_id_vs_color[pop_id] - - if plane2d == "xy": - min_xaxis, max_xaxis = add_line( - ax, - [pos[0], pos[0]], - [pos[1], pos[1]], - radius, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "yx": - min_xaxis, max_xaxis = add_line( - ax, - [pos[1], pos[1]], - [pos[0], pos[0]], - radius, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "xz": - min_xaxis, max_xaxis = add_line( - ax, - [pos[0], pos[0]], - [pos[2], pos[2]], - radius, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "zx": - min_xaxis, max_xaxis = add_line( - ax, - [pos[2], pos[2]], - [pos[0], pos[0]], - radius, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "yz": - min_xaxis, max_xaxis = add_line( - ax, - [pos[1], pos[1]], - [pos[2], pos[2]], - radius, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "zy": - min_xaxis, max_xaxis = add_line( - ax, - [pos[2], pos[2]], - [pos[1], pos[1]], - radius, - color, - min_xaxis, - max_xaxis, + plot_2D_point_cells( + offset=pos, + plane2d=plane2d, + color=color, + soma_radius=radius, + verbose=verbose, + ax=ax, + fig=fig, + autoscale=False, + scalebar=False, + nogui=True, + ) + else: + if plot_type == "Schematic": + plot_2D_schematic( + offset=pos, + cell=cell, + segment_groups=None, + labels=True, + plane2d=plane2d, + verbose=verbose, + fig=fig, + ax=ax, + scalebar=False, + nogui=True, + autoscale=False, + square=False, ) else: - raise Exception(f"Invalid value for plane: {plane2d}") - - else: - - for seg in cell.morphology.segments: - p = cell.get_actual_proximal(seg.id) - d = seg.distal - width = (p.diameter + d.diameter) / 2 - - if width < min_width: - width = min_width - - if plot_type == "Constant": - width = min_width - - color = "b" - if pop_id in pop_id_vs_color: - color = pop_id_vs_color[pop_id] - else: - if seg.id in soma_segs: - color = "g" - if seg.id in axon_segs: - color = "r" - - spherical = ( - p.x == d.x - and p.y == d.y - and p.z == d.z - and p.diameter == d.diameter + plot_2D_cell_morphology( + offset=pos, + cell=cell, + plane2d=plane2d, + color=color, + plot_type=plot_type, + verbose=verbose, + fig=fig, + ax=ax, + min_width=min_width, + axis_min_max=axis_min_max, + scalebar=False, + nogui=True, + autoscale=False, + square=False, ) - if verbose: - print( - "\nSeg %s, id: %s%s has proximal: %s, distal: %s (width: %s, min_width: %s), color: %s" - % ( - seg.name, - seg.id, - " (spherical)" if spherical else "", - p, - d, - width, - min_width, - str(color), - ) - ) - - if plane2d == "xy": - min_xaxis, max_xaxis = add_line( - ax, - [pos[0] + p.x, pos[0] + d.x], - [pos[1] + p.y, pos[1] + d.y], - width, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "yx": - min_xaxis, max_xaxis = add_line( - ax, - [pos[1] + p.y, pos[1] + d.y], - [pos[0] + p.x, pos[0] + d.x], - width, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "xz": - min_xaxis, max_xaxis = add_line( - ax, - [pos[0] + p.x, pos[0] + d.x], - [pos[2] + p.z, pos[2] + d.z], - width, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "zx": - min_xaxis, max_xaxis = add_line( - ax, - [pos[2] + p.z, pos[2] + d.z], - [pos[0] + p.x, pos[0] + d.x], - width, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "yz": - min_xaxis, max_xaxis = add_line( - ax, - [pos[1] + p.y, pos[1] + d.y], - [pos[2] + p.z, pos[2] + d.z], - width, - color, - min_xaxis, - max_xaxis, - ) - elif plane2d == "zy": - min_xaxis, max_xaxis = add_line( - ax, - [pos[2] + p.z, pos[2] + d.z], - [pos[1] + p.y, pos[1] + d.y], - width, - color, - min_xaxis, - max_xaxis, - ) - else: - raise Exception(f"Invalid value for plane: {plane2d}") - - if verbose: - print("Extent x: %s -> %s" % (min_xaxis, max_xaxis)) - - # add a scalebar - # ax = fig.add_axes([0, 0, 1, 1]) - sc_val = 50 - if max_xaxis - min_xaxis < 100: - sc_val = 5 - if max_xaxis - min_xaxis < 10: - sc_val = 1 - scalebar1 = ScaleBar( - 0.001, - units="mm", - dimension="si-length", - scale_loc="top", - location="lower right", - fixed_value=sc_val, - fixed_units="um", - box_alpha=0.8, - ) - ax.add_artist(scalebar1) - - plt.autoscale() - xl = plt.xlim() - yl = plt.ylim() - if verbose: - print("Auto limits - x: %s , y: %s" % (xl, yl)) - - small = 0.1 - if xl[1] - xl[0] < small and yl[1] - yl[0] < small: # i.e. only a point - plt.xlim([-100, 100]) - plt.ylim([-100, 100]) - elif xl[1] - xl[0] < small: - d_10 = (yl[1] - yl[0]) / 10 - m = xl[0] + (xl[1] - xl[0]) / 2.0 - plt.xlim([m - d_10, m + d_10]) - elif yl[1] - yl[0] < small: - d_10 = (xl[1] - xl[0]) / 10 - m = yl[0] + (yl[1] - yl[0]) / 2.0 - plt.ylim([m - d_10, m + d_10]) - - if square: - if xl[1] - xl[0] > yl[1] - yl[0]: - d2 = (xl[1] - xl[0]) / 2 - m = yl[0] + (yl[1] - yl[0]) / 2.0 - plt.ylim([m - d2, m + d2]) - - if xl[1] - xl[0] < yl[1] - yl[0]: - d2 = (yl[1] - yl[0]) / 2 - m = xl[0] + (xl[1] - xl[0]) / 2.0 - plt.xlim([m - d2, m + d2]) + add_scalebar_to_matplotlib_plot(axis_min_max, ax) + autoscale_matplotlib_plot(verbose, square) if save_to_file: abs_file = os.path.abspath(save_to_file) @@ -526,38 +522,18 @@ def plot_2D( if not nogui: plt.show() + if close_plot: + logger.info("Closing plot") + plt.close() -def add_line(ax, xv, yv, width, color, min_xaxis, max_xaxis): - - if ( - abs(xv[0] - xv[1]) < 0.01 and abs(yv[0] - yv[1]) < 0.01 - ): # looking at the cylinder from the top, OR a sphere, so draw a circle - xv[1] = xv[1] + width / 1000.0 - yv[1] = yv[1] + width / 1000.0 - - ax.add_line( - LineDataUnits(xv, yv, linewidth=width, solid_capstyle="round", color=color) - ) - - ax.add_line( - LineDataUnits(xv, yv, linewidth=width, solid_capstyle="butt", color=color) - ) - - min_xaxis = min(min_xaxis, xv[0]) - min_xaxis = min(min_xaxis, xv[1]) - max_xaxis = max(max_xaxis, xv[0]) - max_xaxis = max(max_xaxis, xv[1]) - return min_xaxis, max_xaxis - - -def plot_interactive_3D( +def plot_3D_cell_morphology_plotly( nml_file: str, min_width: float = 0.8, verbose: bool = False, nogui: bool = False, save_to_file: typing.Optional[str] = None, - plot_type: str = "Detailed" + plot_type: str = "Detailed", ): """Plot NeuroML2 cell morphology interactively using Plot.ly @@ -580,15 +556,17 @@ def plot_interactive_3D( :param save_to_file: optional filename to save generated morphology to :type save_to_file: str :param plot_type: type of plot, one of: + - Detailed: show detailed morphology taking into account each segment's width - Constant: show morphology, but use constant line widths - - Schematic: only plot each unbranched segment group as a straight - line, not following each segment + :type plot_type: str """ - if plot_type not in ["Detailed", "Constant", "Schematic"]: - raise ValueError("plot_type must be one of 'Detailed', 'Constant', or 'Schematic'") + if plot_type not in ["Detailed", "Constant"]: + raise ValueError( + "plot_type must be one of 'Detailed', 'Constant', or 'Schematic'" + ) nml_model = read_neuroml2_file(nml_file) @@ -661,5 +639,1223 @@ def plot_interactive_3D( logger.info("Saved image to %s of plot: %s" % (save_to_file, title)) +def plot_2D_cell_morphology( + offset: typing.List[float] = [0, 0], + cell: Cell = None, + plane2d: str = "xy", + color: typing.Optional[str] = None, + title: str = "", + verbose: bool = False, + fig: matplotlib.figure.Figure = None, + ax: matplotlib.axes.Axes = None, + min_width: float = DEFAULTS["minWidth"], + axis_min_max: typing.List = [float("inf"), -1 * float("inf")], + scalebar: bool = False, + nogui: bool = True, + autoscale: bool = True, + square: bool = False, + plot_type: str = "Detailed", + save_to_file: typing.Optional[str] = None, + close_plot: bool = False, + overlay_data: typing.Dict[int, float] = None, + overlay_data_label: typing.Optional[str] = None, + datamin: typing.Optional[float] = None, + datamax: typing.Optional[float] = None, + colormap_name: str = "viridis", +): + """Plot the detailed 2D morphology of a cell in provided plane. + + The method can also overlay data onto the morphology. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_schematic` + for plotting only segmeng groups with their labels + + :py:func:`plot_2D_point_cells` + for plotting point cells + + :param offset: offset for cell + :type offset: [float, float] + :param cell: cell to plot + :type cell: neuroml.Cell + :param plane2d: plane to plot on + :type plane2d: str + :param color: color to use for all segments + :type color: str + :param fig: a matplotlib.figure.Figure object to use + :type fig: matplotlib.figure.Figure + :param ax: a matplotlib.axes.Axes object to use + :type ax: matplotlib.axes.Axes + :param min_width: minimum width for segments (useful for visualising very + thin segments): default 0.8um + :type min_width: float + :param axis_min_max: min, max value of axes + :type axis_min_max: [float, float] + :param title: title of plot + :type title: str + :param verbose: show extra information (default: False) + :type verbose: bool + :param nogui: do not show matplotlib GUI (default: false) + :type nogui: bool + :param save_to_file: optional filename to save generated morphology to + :type save_to_file: str + :param square: scale axes so that image is approximately square + :type square: bool + :param autoscale: toggle autoscaling + :type autoscale: bool + :param scalebar: toggle scalebar + :type scalebar: bool + :param close_plot: call pyplot.close() to close plot after plotting + :type close_plot: bool + :param overlay_data: data to overlay over the morphology + this must be a dictionary with segment ids as keys, the single value to + overlay as values + :type overlay_data: dict, keys are segment ids, values are magnitudes to + overlay on curtain plots + :param overlay_data_label: label of data being overlaid + :type overlay_data_label: str + :param colormap_name: name of matplotlib colourmap to use for data overlay + See: + https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.colormaps + Note: random colours are used for each segment if no data is to be overlaid + :type colormap_name: str + :param datamin: min limits of data (useful to compare different plots) + :type datamin: float + :param datamax: max limits of data (useful to compare different plots) + :type datamax: float + + :raises: ValueError if `cell` is None + + """ + if cell is None: + raise ValueError( + "No cell provided. If you would like to plot a network of point neurons, consider using `plot_2D_point_cells` instead" + ) + + try: + soma_segs = cell.get_all_segments_in_group("soma_group") + except Exception: + soma_segs = [] + try: + dend_segs = cell.get_all_segments_in_group("dendrite_group") + except Exception: + dend_segs = [] + try: + axon_segs = cell.get_all_segments_in_group("axon_group") + except Exception: + axon_segs = [] + + if fig is None: + fig, ax = get_new_matplotlib_morph_plot(title) + + # overlaying data + data_max = -1 * float("inf") + data_min = float("inf") + acolormap = None + norm = None + if overlay_data: + this_max = numpy.max(list(overlay_data.values())) + this_min = numpy.min(list(overlay_data.values())) + if this_max > data_max: + data_max = this_max + if this_min < data_min: + data_min = this_min + + if datamin is not None: + data_min = datamin + if datamax is not None: + data_max = datamax + + acolormap = matplotlib.colormaps[colormap_name] + norm = matplotlib.colors.Normalize(vmin=data_min, vmax=data_max) + fig.colorbar( + matplotlib.cm.ScalarMappable(norm=norm, cmap=acolormap), + label=overlay_data_label, + ) + + # random default color + for seg in cell.morphology.segments: + p = cell.get_actual_proximal(seg.id) + d = seg.distal + width = (p.diameter + d.diameter) / 2 + + if width < min_width: + width = min_width + + if plot_type == "Constant": + width = min_width + + if overlay_data and acolormap and norm: + try: + seg_color = acolormap(norm(overlay_data[seg.id])) + except KeyError: + seg_color = "black" + else: + seg_color = "b" + if seg.id in soma_segs: + seg_color = "g" + elif seg.id in axon_segs: + seg_color = "r" + + spherical = ( + p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter + ) + + if verbose: + logger.info( + "\nSeg %s, id: %s%s has proximal: %s, distal: %s (width: %s, min_width: %s), color: %s" + % ( + seg.name, + seg.id, + " (spherical)" if spherical else "", + p, + d, + width, + min_width, + str(seg_color), + ) + ) + + if plane2d == "xy": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + p.x, offset[0] + d.x], + [offset[1] + p.y, offset[1] + d.y], + width, + seg_color if color is None else color, + axis_min_max, + ) + elif plane2d == "yx": + add_line_to_matplotlib_2D_plot( + ax, + [offset[1] + p.y, offset[1] + d.y], + [offset[0] + p.x, offset[0] + d.x], + width, + seg_color if color is None else color, + axis_min_max, + ) + elif plane2d == "xz": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + p.x, offset[0] + d.x], + [offset[2] + p.z, offset[2] + d.z], + width, + seg_color if color is None else color, + axis_min_max, + ) + elif plane2d == "zx": + add_line_to_matplotlib_2D_plot( + ax, + [offset[2] + p.z, offset[2] + d.z], + [offset[0] + p.x, offset[0] + d.x], + width, + seg_color if color is None else color, + axis_min_max, + ) + elif plane2d == "yz": + add_line_to_matplotlib_2D_plot( + ax, + [offset[1] + p.y, offset[1] + d.y], + [offset[2] + p.z, offset[2] + d.z], + width, + seg_color if color is None else color, + axis_min_max, + ) + elif plane2d == "zy": + add_line_to_matplotlib_2D_plot( + ax, + [offset[2] + p.z, offset[2] + d.z], + [offset[1] + p.y, offset[1] + d.y], + width, + seg_color if color is None else color, + axis_min_max, + ) + else: + raise Exception(f"Invalid value for plane: {plane2d}") + + if verbose: + print("Extent x: %s -> %s" % (axis_min_max[0], axis_min_max[1])) + + if scalebar: + add_scalebar_to_matplotlib_plot(axis_min_max, ax) + if autoscale: + autoscale_matplotlib_plot(verbose, square) + + if save_to_file: + abs_file = os.path.abspath(save_to_file) + plt.savefig(abs_file, dpi=200, bbox_inches="tight") + print(f"Saved image on plane {plane2d} to {abs_file} of plot: {title}") + + if not nogui: + plt.show() + if close_plot: + logger.info("closing plot") + plt.close() + + +def plot_3D_cell_morphology( + offset: typing.List[float] = [0, 0, 0], + cell: Cell = None, + color: typing.Optional[str] = None, + title: str = "", + verbose: bool = False, + current_scene: scene.SceneCanvas = None, + current_view: scene.ViewBox = None, + min_width: float = DEFAULTS["minWidth"], + axis_min_max: typing.List = [float("inf"), -1 * float("inf")], + nogui: bool = True, + plot_type: str = "Constant", + theme="light", +): + """Plot the detailed 3D morphology of a cell using vispy. + https://vispy.org/ + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_schematic` + for plotting only segmeng groups with their labels + + :py:func:`plot_2D_point_cells` + for plotting point cells + + :param offset: offset for cell + :type offset: [float, float] + :param cell: cell to plot + :type cell: neuroml.Cell + :param color: color to use for segments: + + - if None, each segment is given a new unique color + - if "Groups", each unbranched segment group is given a unique color, + and segments that do not belong to an unbranched segment group are in + white + - if "Default Groups", axonal segments are in red, dendritic in blue, + somatic in green, and others in white + + :type color: str + :param min_width: minimum width for segments (useful for visualising very + :type min_width: float + :param axis_min_max: min, max value of axes + :type axis_min_max: [float, float] + :param title: title of plot + :type title: str + :param verbose: show extra information (default: False) + :type verbose: bool + :param nogui: do not show image immediately + :type nogui: bool + :param current_scene: vispy SceneCanvas to use (a new one is created if it is not + provided) + :type current_scene: SceneCanvas + :param current_view: vispy viewbox to use + :type current_view: ViewBox + :param plot_type: type of plot, one of: + + - "Detailed": show detailed morphology taking into account each segment's + width. This is not performant, because a new line is required for + each segment. To only be used for cells with small numbers of + segments + - "Constant": show morphology, but use constant line widths + + This is only applicable for neuroml.Cell cells (ones with some + morphology) + + :type plot_type: str + :param theme: theme to use (dark/light) + :type theme: str + :raises: ValueError if `cell` is None + + """ + if cell is None: + raise ValueError( + "No cell provided. If you would like to plot a network of point neurons, consider using `plot_2D_point_cells` instead" + ) + + try: + soma_segs = cell.get_all_segments_in_group("soma_group") + except Exception: + soma_segs = [] + try: + dend_segs = cell.get_all_segments_in_group("dendrite_group") + except Exception: + dend_segs = [] + try: + axon_segs = cell.get_all_segments_in_group("axon_group") + except Exception: + axon_segs = [] + + if current_scene is None or current_view is None: + view_min, view_max = get_cell_bound_box(cell) + current_scene, current_view = create_new_vispy_canvas( + view_min, view_max, title, theme=theme + ) + + if color == "Groups": + color_dict = {} + # if no segment groups are given, do them all + segment_groups = [] + for sg in cell.morphology.segment_groups: + if sg.neuro_lex_id == neuro_lex_ids["section"]: + segment_groups.append(sg.id) + + ord_segs = cell.get_ordered_segments_in_groups( + segment_groups, check_parentage=False + ) + + for sgs, segs in ord_segs.items(): + c = get_next_hex_color() + for s in segs: + color_dict[s.id] = c + + # for lines/segments + points = [] + toconnect = [] + colors = [] + # for any spheres which we plot as markers at once + marker_points = [] + marker_colors = [] + marker_sizes = [] + + for seg in cell.morphology.segments: + p = cell.get_actual_proximal(seg.id) + d = seg.distal + width = (p.diameter + d.diameter) / 2 + + if width < min_width: + width = min_width + + if plot_type == "Constant": + width = min_width + + seg_color = "white" + if color is None: + seg_color = get_next_hex_color() + elif color == "Groups": + try: + seg_color = color_dict[seg.id] + except KeyError: + print(f"Unbranched segment found: {seg.id}") + if seg.id in soma_segs: + seg_color = "green" + elif seg.id in axon_segs: + seg_color = "red" + elif seg.id in dend_segs: + seg_color = "blue" + elif color == "Default Groups": + if seg.id in soma_segs: + seg_color = "green" + elif seg.id in axon_segs: + seg_color = "red" + elif seg.id in dend_segs: + seg_color = "blue" + else: + seg_color = color + + # check if for a spherical segment, add extra spherical node + if p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter: + marker_points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) + marker_colors.append(seg_color) + marker_sizes.append(p.diameter) + + if plot_type == "Constant": + points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) + colors.append(seg_color) + points.append([offset[0] + d.x, offset[1] + d.y, offset[2] + d.z]) + colors.append(seg_color) + toconnect.append([len(points) - 2, len(points) - 1]) + # every segment plotted individually + elif plot_type == "Detailed": + points = [] + toconnect = [] + colors = [] + points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) + colors.append(seg_color) + points.append([offset[0] + d.x, offset[1] + d.y, offset[2] + d.z]) + colors.append(seg_color) + toconnect.append([len(points) - 2, len(points) - 1]) + scene.Line( + pos=points, + color=colors, + connect=numpy.array(toconnect), + parent=current_view.scene, + width=width, + ) + + if plot_type == "Constant": + scene.Line( + pos=points, + color=colors, + connect=numpy.array(toconnect), + parent=current_view.scene, + width=width, + ) + + if not nogui: + # markers + if len(marker_points) > 0: + scene.Markers( + pos=numpy.array(marker_points), + size=numpy.array(marker_sizes), + spherical=True, + face_color=marker_colors, + edge_color=marker_colors, + edge_width=0, + parent=current_view.scene, + scaling=True, + antialias=0, + ) + app.run() + return marker_points, marker_sizes, marker_colors + + +def plot_2D_point_cells( + offset: typing.List[float] = [0, 0], + plane2d: str = "xy", + color: typing.Optional[str] = None, + soma_radius: float = 10.0, + title: str = "", + verbose: bool = False, + fig: matplotlib.figure.Figure = None, + ax: matplotlib.axes.Axes = None, + axis_min_max: typing.List = [float("inf"), -1 * float("inf")], + scalebar: bool = False, + nogui: bool = True, + autoscale: bool = True, + square: bool = False, + save_to_file: typing.Optional[str] = None, + close_plot: bool = False, +): + """Plot point cells. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_schematic` + for plotting only segmeng groups with their labels + + :py:func:`plot_2D_cell_morphology` + for plotting cells with detailed morphologies + + :param offset: location of cell + :type offset: [float, float] + :param plane2d: plane to plot on + :type plane2d: str + :param color: color to use for cell + :type color: str + :param soma_radius: radius of soma + :type soma_radius: float + :param fig: a matplotlib.figure.Figure object to use + :type fig: matplotlib.figure.Figure + :param ax: a matplotlib.axes.Axes object to use + :type ax: matplotlib.axes.Axes + :param axis_min_max: min, max value of axes + :type axis_min_max: [float, float] + :param title: title of plot + :type title: str + :param verbose: show extra information (default: False) + :type verbose: bool + :param nogui: do not show matplotlib GUI (default: false) + :type nogui: bool + :param save_to_file: optional filename to save generated morphology to + :type save_to_file: str + :param square: scale axes so that image is approximately square + :type square: bool + :param autoscale: toggle autoscaling + :type autoscale: bool + :param scalebar: toggle scalebar + :type scalebar: bool + :param close_plot: call pyplot.close() to close plot after plotting + :type close_plot: bool + """ + if fig is None: + fig, ax = get_new_matplotlib_morph_plot(title) + + cell_color = get_next_hex_color() + + if plane2d == "xy": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0], offset[0]], + [offset[1], offset[1]], + soma_radius, + cell_color if color is None else color, + axis_min_max, + ) + elif plane2d == "yx": + add_line_to_matplotlib_2D_plot( + ax, + [offset[1], offset[1]], + [offset[0], offset[0]], + soma_radius, + cell_color if color is None else color, + axis_min_max, + ) + elif plane2d == "xz": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0], offset[0]], + [offset[2], offset[2]], + soma_radius, + cell_color if color is None else color, + axis_min_max, + ) + elif plane2d == "zx": + add_line_to_matplotlib_2D_plot( + ax, + [offset[2], offset[2]], + [offset[0], offset[0]], + soma_radius, + cell_color if color is None else color, + axis_min_max, + ) + elif plane2d == "yz": + add_line_to_matplotlib_2D_plot( + ax, + [offset[1], offset[1]], + [offset[2], offset[2]], + soma_radius, + cell_color if color is None else color, + axis_min_max, + ) + elif plane2d == "zy": + add_line_to_matplotlib_2D_plot( + ax, + [offset[2], offset[2]], + [offset[1], offset[1]], + soma_radius, + cell_color if color is None else color, + axis_min_max, + ) + else: + raise Exception(f"Invalid value for plane: {plane2d}") + + if scalebar: + add_scalebar_to_matplotlib_plot(axis_min_max, ax) + if autoscale: + autoscale_matplotlib_plot(verbose, square) + + if save_to_file: + abs_file = os.path.abspath(save_to_file) + plt.savefig(abs_file, dpi=200, bbox_inches="tight") + print(f"Saved image on plane {plane2d} to {abs_file} of plot: {title}") + + if not nogui: + plt.show() + if close_plot: + logger.info("closing plot") + plt.close() + + +def plot_2D_schematic( + cell: Cell, + segment_groups: typing.Optional[typing.List[SegmentGroup]], + offset: typing.List[float] = [0, 0], + labels: bool = False, + plane2d: str = "xy", + width: float = 2.0, + verbose: bool = False, + square: bool = False, + nogui: bool = False, + save_to_file: typing.Optional[str] = None, + scalebar: bool = True, + autoscale: bool = True, + fig: matplotlib.figure.Figure = None, + ax: matplotlib.axes.Axes = None, + title: str = "", + close_plot: bool = False, +) -> None: + """Plot a 2D schematic of the provided segment groups. + + This plots each segment group as a straight line between its first and last + segment. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_point_cells` + for plotting point cells + + :py:func:`plot_2D_cell_morphology` + for plotting cells with detailed morphologies + + :param offset: offset for cell + :type offset: [float, float] + :param cell: cell to plot + :type cell: neuroml.Cell + :param segment_groups: list of unbranched segment groups to plot + :type segment_groups: list(SegmentGroup) + :param labels: toggle labelling of segment groups + :type labels: bool + :param plane2d: what plane to plot (xy/yx/yz/zy/zx/xz) + :type plane2d: str + :param width: width for lines + :type width: float + :param verbose: show extra information (default: False) + :type verbose: bool + :param square: scale axes so that image is approximately square + :type square: bool + :param nogui: do not show matplotlib GUI (default: false) + :type nogui: bool + :param save_to_file: optional filename to save generated morphology to + :type save_to_file: str + :param fig: a matplotlib.figure.Figure object to use + :type fig: matplotlib.figure.Figure + :param ax: a matplotlib.axes.Axes object to use + :type ax: matplotlib.axes.Axes + :param title: title of plot + :type title: str + :param square: scale axes so that image is approximately square + :type square: bool + :param autoscale: toggle autoscaling + :type autoscale: bool + :param scalebar: toggle scalebar + :type scalebar: bool + :param close_plot: call pyplot.close() to close plot after plotting + :type close_plot: bool + + """ + if title == "": + title = f"2D schematic of segment groups from {cell.id}" + + # if no segment groups are given, do them all + if segment_groups is None: + segment_groups = [] + for sg in cell.morphology.segment_groups: + if sg.neuro_lex_id == neuro_lex_ids["section"]: + segment_groups.append(sg.id) + + ord_segs = cell.get_ordered_segments_in_groups( + segment_groups, check_parentage=False + ) + + if fig is None: + logger.debug("No figure provided, creating new fig and ax") + fig, ax = get_new_matplotlib_morph_plot(title, plane2d) + + if plane2d == "xy": + ax.set_xlabel("x (μm)") + ax.set_ylabel("y (μm)") + elif plane2d == "yx": + ax.set_xlabel("y (μm)") + ax.set_ylabel("x (μm)") + elif plane2d == "xz": + ax.set_xlabel("x (μm)") + ax.set_ylabel("z (μm)") + elif plane2d == "zx": + ax.set_xlabel("z (μm)") + ax.set_ylabel("x (μm)") + elif plane2d == "yz": + ax.set_xlabel("y (μm)") + ax.set_ylabel("z (μm)") + elif plane2d == "zy": + ax.set_xlabel("z (μm)") + ax.set_ylabel("y (μm)") + else: + logger.error(f"Invalid value for plane: {plane2d}") + sys.exit(-1) + + # use a mutable object so it can be passed as an argument to methods, using + # float (immuatable) variables requires us to return these from all methods + axis_min_max = [float("inf"), -1 * float("inf")] + width = 1 + + for sgid, segs in ord_segs.items(): + sgobj = cell.get_segment_group(sgid) + if sgobj.neuro_lex_id != neuro_lex_ids["section"]: + raise ValueError( + f"{sgobj} does not have neuro_lex_id set to indicate it is an unbranched segment" + ) + + # get proximal and distal points + first_seg = segs[0] # type: Segment + last_seg = segs[-1] # type: Segment + + # unique color for each segment group + color = get_next_hex_color() + + if plane2d == "xy": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], + [offset[1] + first_seg.proximal.y, offset[1] + last_seg.distal.y], + width, + color, + axis_min_max, + ) + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], + [offset[1] + first_seg.proximal.y, offset[1] + last_seg.distal.y], + color=color, + text=sgid, + ) + + elif plane2d == "yx": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], + [offset[1] + first_seg.proximal.x, offset[1] + last_seg.distal.x], + width, + color, + axis_min_max, + ) + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], + [offset[1] + first_seg.proximal.x, offset[1] + last_seg.distal.x], + color=color, + text=sgid, + ) + elif plane2d == "xz": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], + [offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], + width, + color, + axis_min_max, + ) + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], + [offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], + color=color, + text=sgid, + ) + elif plane2d == "zx": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.z, offset[0] + last_seg.distal.z], + [offset[1] + first_seg.proximal.x, offset[1] + last_seg.distal.x], + width, + color, + axis_min_max, + ) + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.z, offset[0] + last_seg.distal.z], + [offset[1] + first_seg.proximal.x, offset[1] + last_seg.distal.x], + color=color, + text=sgid, + ) + elif plane2d == "yz": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], + [offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], + width, + color, + axis_min_max, + ) + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], + [offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], + color=color, + text=sgid, + ) + elif plane2d == "zy": + add_line_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.z, offset[0] + last_seg.distal.z], + [offset[1] + first_seg.proximal.y, offset[1] + last_seg.distal.y], + width, + color, + axis_min_max, + ) + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [offset[0] + first_seg.proximal.z, offset[0] + last_seg.distal.z], + [offset[1] + first_seg.proximal.y, offset[1] + last_seg.distal.y], + color=color, + text=sgid, + ) + else: + raise Exception(f"Invalid value for plane: {plane2d}") + + if verbose: + print("Extent x: %s -> %s" % (axis_min_max[0], axis_min_max[1])) + + if scalebar: + add_scalebar_to_matplotlib_plot(axis_min_max, ax) + if autoscale: + autoscale_matplotlib_plot(verbose, square) + + if save_to_file: + abs_file = os.path.abspath(save_to_file) + plt.savefig(abs_file, dpi=200, bbox_inches="tight") + print(f"Saved image on plane {plane2d} to {abs_file} of plot: {title}") + + if not nogui: + plt.show() + if close_plot: + logger.info("closing plot") + plt.close() + + +def plot_3D_schematic( + cell: Cell, + segment_groups: typing.Optional[typing.List[SegmentGroup]], + offset: typing.List[float] = [0, 0, 0], + labels: bool = False, + width: float = 5.0, + verbose: bool = False, + nogui: bool = False, + title: str = "", + current_scene: scene.SceneCanvas = None, + current_view: scene.ViewBox = None, + theme: str = "light", +) -> None: + """Plot a 3D schematic of the provided segment groups in Napari as a new + layer.. + + This plots each segment group as a straight line between its first and last + segment. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D_schematic` + general function for plotting + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_point_cells` + for plotting point cells + + :py:func:`plot_2D_cell_morphology` + for plotting cells with detailed morphologies + + :param offset: offset for cell + :type offset: [float, float, float] + :param cell: cell to plot + :type cell: neuroml.Cell + :param segment_groups: list of unbranched segment groups to plot + :type segment_groups: list(SegmentGroup) + :param labels: toggle labelling of segment groups + :type labels: bool + :param width: width for lines for segment groups + :type width: float + :param verbose: show extra information (default: False) + :type verbose: bool + :param title: title of plot + :type title: str + :param nogui: toggle if plot should be shown or not + :type nogui: bool + :param current_scene: vispy SceneCanvas to use (a new one is created if it is not + provided) + :type current_scene: SceneCanvas + :param current_view: vispy viewbox to use + :type current_view: ViewBox + :param theme: theme to use (light/dark) + :type theme: str + """ + if title == "": + title = f"3D schematic of segment groups from {cell.id}" + + # if no segment groups are given, do them all + if segment_groups is None: + segment_groups = [] + for sg in cell.morphology.segment_groups: + if sg.neuro_lex_id == neuro_lex_ids["section"]: + segment_groups.append(sg.id) + + ord_segs = cell.get_ordered_segments_in_groups( + segment_groups, check_parentage=False + ) + + # if no canvas is defined, define a new one + if current_scene is None or current_view is None: + view_min, view_max = get_cell_bound_box(cell) + current_scene, current_view = create_new_vispy_canvas( + view_min, view_max, title, theme=theme + ) + + points = [] + toconnect = [] + colors = [] + text = [] + textpoints = [] + + for sgid, segs in ord_segs.items(): + sgobj = cell.get_segment_group(sgid) + if sgobj.neuro_lex_id != neuro_lex_ids["section"]: + raise ValueError( + f"{sgobj} does not have neuro_lex_id set to indicate it is an unbranched segment" + ) + + # get proximal and distal points + first_seg = segs[0] # type: Segment + last_seg = segs[-1] # type: Segment + first_prox = cell.get_actual_proximal(first_seg.id) + + points.append( + [ + offset[0] + first_prox.x, + offset[1] + first_prox.y, + offset[2] + first_prox.z, + ] + ) + points.append( + [ + offset[0] + last_seg.distal.x, + offset[1] + last_seg.distal.y, + offset[2] + last_seg.distal.z, + ] + ) + colors.append(get_next_hex_color()) + colors.append(get_next_hex_color()) + toconnect.append([len(points) - 2, len(points) - 1]) + + # TODO: needs fixing to show labels + if labels: + text.append(f"{sgid}") + textpoints.append( + [ + offset[0] + (first_prox.x + last_seg.distal.x) / 2, + offset[1] + (first_prox.y + last_seg.distal.y) / 2, + offset[2] + (first_prox.z + last_seg.distal.z) / 2, + ] + ) + """ + + alabel = add_text_to_vispy_3D_plot(current_scene=current_view.scene, text=f"{sgid}", + xv=[offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], + yv=[offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], + zv=[offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], + color=colors[-1]) + alabel.font_size = 30 + """ + + scene.Line( + points, + parent=current_view.scene, + color=colors, + width=width, + connect=numpy.array(toconnect), + ) + if labels: + print("Text rendering") + scene.Text( + text, pos=textpoints, font_size=30, color="black", parent=current_view.scene + ) + + if not nogui: + app.run() + + +def plot_segment_groups_curtain_plots( + cell: Cell, + segment_groups: typing.List[SegmentGroup], + labels: bool = False, + verbose: bool = False, + nogui: bool = False, + save_to_file: typing.Optional[str] = None, + overlay_data: typing.Dict[str, typing.List[typing.Any]] = None, + overlay_data_label: str = "", + width: typing.Union[float, int] = 4, + colormap_name: str = "viridis", + title: str = "SegmentGroup", + datamin: typing.Optional[float] = None, + datamax: typing.Optional[float] = None, + close_plot: bool = False, +) -> None: + """Plot curtain plots of provided segment groups. + + .. versionadded:: 1.0.0 + + :param cell: cell to plot + :type cell: neuroml.Cell + :param segment_groups: list of unbranched segment groups to plot + :type segment_groups: list(SegmentGroup) + :param labels: toggle labelling of segment groups + :type labels: bool + :param verbose: show extra information (default: False) + :type verbose: bool + :param nogui: do not show matplotlib GUI (default: false) + :type nogui: bool + :param save_to_file: optional filename to save generated morphology to + :type save_to_file: str + :param overlay_data: data to overlay over the curtain plots; + this must be a dictionary with segment group ids as keys, and lists of + values to overlay as values. Each list should have a value for every + segment in the segment group. + :type overlay_data: dict, keys are segment group ids, values are lists of + magnitudes to overlay on curtain plots + :param overlay_data_label: label of data being overlaid + :type overlay_data_label: str + :param width: width of each segment group + :type width: float/int + :param colormap_name: name of matplotlib colourmap to use for data overlay + See: + https://matplotlib.org/stable/api/matplotlib_configuration_api.html#matplotlib.colormaps + Note: random colours are used for each segment if no data is to be overlaid + :type colormap_name: str + :param title: plot title, displayed at bottom + :type title: str + :param datamin: min limits of data (useful to compare different plots) + :type datamin: float + :param datamax: max limits of data (useful to compare different plots) + :type datamax: float + :param close_plot: call pyplot.close() to close plot after plotting + :type close_plot: bool + :returns: None + + :raises ValueError: if keys in `overlay_data` do not match + ids of segment groups of `segment_groups` + :raises ValueError: if number of items for each key in `overlay_data` does + not match the number of segments in the corresponding segment group + """ + # use a random number generator so that the colours are always the same + myrandom = random.Random() + myrandom.seed(122436) + + (ord_segs, cumulative_lengths) = cell.get_ordered_segments_in_groups( + segment_groups, check_parentage=False, include_cumulative_lengths=True + ) + + # plot setup + fig, ax = plt.subplots(1, 1) # noqa + plt.get_current_fig_manager().set_window_title(title) + + # overlaying data related checks + data_max = -1 * float("inf") + data_min = float("inf") + acolormap = None + norm = None + if overlay_data: + if set(overlay_data.keys()) != set(ord_segs.keys()): + raise ValueError( + f"Keys of overlay_data ({overlay_data.keys()}) and ord_segs ({ord_segs.keys()})must match." + ) + for key in overlay_data.keys(): + if len(overlay_data[key]) != len(ord_segs[key]): + raise ValueError( + f"Number of values for key {key} does not match in overlay_data({len(overlay_data[key])}) and the segment group ({len(ord_segs[key])})" + ) + + # since lists are of different lengths, one cannot use `numpy.max` + # on all the values directly + this_max = numpy.max(list(overlay_data[key])) + this_min = numpy.min(list(overlay_data[key])) + if this_max > data_max: + data_max = this_max + if this_min < data_min: + data_min = this_min + + if datamin is not None: + data_min = datamin + if datamax is not None: + data_max = datamax + + acolormap = matplotlib.colormaps[colormap_name] + norm = matplotlib.colors.Normalize(vmin=data_min, vmax=data_max) + fig.colorbar( + matplotlib.cm.ScalarMappable(norm=norm, cmap=acolormap), + label=overlay_data_label, + ) + + ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.yaxis.set_ticks_position("left") + ax.xaxis.set_ticks_position("none") + ax.xaxis.set_ticks([]) + + ax.set_xlabel(title) + ax.set_ylabel("length (μm)") + + # column counter + column = 0 + for sgid, segs in ord_segs.items(): + column += 1 + length = 0 + + cumulative_lengths_sg = cumulative_lengths[sgid] + + sgobj = cell.get_segment_group(sgid) + if sgobj.neuro_lex_id != neuro_lex_ids["section"]: + raise ValueError( + f"{sgobj} does not have neuro_lex_id set to indicate it is an unbranched segment" + ) + + for seg_num in range(0, len(segs)): + seg = segs[seg_num] + cumulative_len = cumulative_lengths_sg[seg_num] + + if overlay_data and acolormap and norm: + color = acolormap(norm(overlay_data[sgid][seg_num])) + else: + color = get_next_hex_color(myrandom) + + logger.debug(f"color is {color}") + + add_box_to_matplotlib_2D_plot( + ax, + [column * width - width * 0.10, -1 * length], + height=cumulative_len, + width=width * 0.8, + color=color, + ) + + length += cumulative_len + + if labels: + add_text_to_matplotlib_2D_plot( + ax, + [column * width + width / 2, column * width + width / 2], + [50, 100], + color="black", + text=sgid, + vertical="bottom", + horizontal="center", + clip_on=False, + ) + + plt.autoscale() + xl = plt.xlim() + yl = plt.ylim() + if verbose: + print("Auto limits - x: %s , y: %s" % (xl, yl)) + + plt.ylim(top=0) + ax.set_yticklabels(abs(ax.get_yticks())) + + if save_to_file: + abs_file = os.path.abspath(save_to_file) + plt.savefig(abs_file, dpi=200, bbox_inches="tight") + print(f"Saved image to {abs_file} of plot: {title}") + + if not nogui: + plt.show() + if close_plot: + logger.info("Closing plot") + plt.close() + + if __name__ == "__main__": main() diff --git a/pyneuroml/povray/NeuroML2ToPOVRay.py b/pyneuroml/povray/NeuroML2ToPOVRay.py index 2c948b47..3554d3a8 100644 --- a/pyneuroml/povray/NeuroML2ToPOVRay.py +++ b/pyneuroml/povray/NeuroML2ToPOVRay.py @@ -1,13 +1,15 @@ -# -# A file for converting NeuroML 2 files (including cells & network structure) -# into POVRay files for 3D rendering -# -# Author: Padraig Gleeson & Matteo Farinella -# -# This file has been developed as part of the neuroConstruct project -# This work has been funded by the Medical Research Council and Wellcome Trust +""" +A file for converting NeuroML 2 files (including cells & network structure) +into POVRay files for 3D rendering + +Author: Padraig Gleeson & Matteo Farinella + +This file has been developed as part of the neuroConstruct project +This work has been funded by the Medical Research Council and Wellcome Trust +""" +import typing import random import argparse import logging @@ -17,9 +19,9 @@ logger = logging.getLogger(__name__) -_WHITE = "<1,1,1,0.55>" -_BLACK = "<0,0,0,0.55>" -_GREY = "<0.85,0.85,0.85,0.55>" +_WHITE = "<1,1,1,0.55>" # type: str +_BLACK = "<0,0,0,0.55>" # type: str +_GREY = "<0.85,0.85,0.85,0.55>" # type: str _DUMMY_CELL = "DUMMY_CELL" @@ -49,7 +51,7 @@ "mindiam": 0, "plane": False, "segids": False, -} +} # type: typing.Dict[str, typing.Any] def process_args(): @@ -258,28 +260,80 @@ def main(): def generate_povray( - neuroml_file, - split=defaults["split"], - background=defaults["background"], - movie=defaults["movie"], - inputs=defaults["inputs"], - conns=defaults["conns"], - conn_points=defaults["conn_points"], - v=defaults["v"], - frames=defaults["frames"], - posx=defaults["posx"], - posy=defaults["posy"], - posz=defaults["posz"], - viewx=defaults["viewx"], - viewy=defaults["viewy"], - viewz=defaults["viewz"], - scalex=defaults["scalex"], - scaley=defaults["scaley"], - scalez=defaults["scalez"], - mindiam=defaults["mindiam"], - plane=defaults["plane"], - segids=defaults["segids"], + neuroml_file: str, + split: bool = defaults["split"], + background: str = defaults["background"], + movie: bool = defaults["movie"], + inputs: bool = defaults["inputs"], + conns: bool = defaults["conns"], + conn_points: bool = defaults["conn_points"], + v: bool = defaults["v"], + frames: bool = defaults["frames"], + posx: float = defaults["posx"], + posy: float = defaults["posy"], + posz: float = defaults["posz"], + viewx: float = defaults["viewx"], + viewy: float = defaults["viewy"], + viewz: float = defaults["viewz"], + scalex: float = defaults["scalex"], + scaley: float = defaults["scaley"], + scalez: float = defaults["scalez"], + mindiam: float = defaults["mindiam"], + plane: bool = defaults["plane"], + segids: bool = defaults["segids"], ): + """Generate a POVRAY image or movie file. + + Please see http://www.povray.org/documentation/ and + https://wiki.povray.org/content/Main_Page for information on installing and + using POVRAY. + + This function will generate POVRAY files that you can then run using + POVRAY. + + :param neuroml_file: path to NeuroML file containing cell/network + :type neuroml_file: str + :param split: generate separate files for cells and network + :type split: bool + :param background: background for POVRAY rendering + :type background: str + :param movie: toggle between image and movie rendering + :type movie: bool + :param inputs: show locations of inputs also + :type inputs: bool + :param conns: show connections in networks with lines + :type conns: bool + :param conn_points: show end points of connections in network + :type conn_points: bool + :param v: toggle verbose output + :type v: bool + :param frames: number of frames to use in movie + :type frames: int + :param posx: offset position in x dir (0 is centre, 1 is top) + :type posx: float + :param posy: offset position in y dir (0 is centre, 1 is top) + :type posy: float + :param posz: offset position in z dir (0 is centre, 1 is top) + :type posz: float + :param viewx: offset viewing point in x dir (0 is centre, 1 is top) + :type viewx: float + :param viewy: offset viewing point in y dir (0 is centre, 1 is top) + :type viewy: float + :param viewz: offset viewing point in z dir (0 is centre, 1 is top) + :type viewz: float + :param scalex: scale position from network in x dir + :type scalex: float + :param scaley: scale position from network in y dir + :type scaley: float + :param scalez: scale position from network in z dir + :type scalez: float + :param mindiam: minimum diameter for dendrites/axons (to improve visualisations) + :type mindiam: float + :param plane: add a 2D plane below cell/network + :type plane: bool + :param segids: toggle showing segment ids + :type segids: bool + """ xmlfile = neuroml_file pov_file_name = xmlfile @@ -682,7 +736,6 @@ def generate_povray( net_file.write("}\n") if conns or conn_points: - projections = ( nml_doc.networks[0].projections + nml_doc.networks[0].electrical_projections diff --git a/pyneuroml/pynml.py b/pyneuroml/pynml.py index 49e3b571..4ada0353 100644 --- a/pyneuroml/pynml.py +++ b/pyneuroml/pynml.py @@ -1930,22 +1930,6 @@ def reload_saved_data( return traces -def get_next_hex_color(my_random: typing.Optional[random.Random] = None) -> str: - """Get a new randomly generated HEX colour code. - - You may pass a random.Random instance that you may be used. Otherwise the - default Python random generator will be used. - - :param my_random: a random.Random object - :type my_random: random.Random - :returns: HEX colour code - """ - if my_random is not None: - return "#%06x" % my_random.randint(0, 0xFFFFFF) - else: - return "#%06x" % random.randint(0, 0xFFFFFF) - - def confirm_file_exists(filename: str) -> None: """Check if a file exists, exit if it does not. diff --git a/pyneuroml/utils/plot.py b/pyneuroml/utils/plot.py new file mode 100644 index 00000000..ebe4a72e --- /dev/null +++ b/pyneuroml/utils/plot.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python3 +""" +Common utils to help with plotting + +File: pyneuroml/utils/plot.py + +Copyright 2023 NeuroML contributors +""" + +import logging +import textwrap +import numpy +import typing +import random +import matplotlib +from matplotlib import pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Rectangle +from matplotlib_scalebar.scalebar import ScaleBar +from vispy import scene +from vispy.scene import SceneCanvas +from vispy.app import Timer +from neuroml import Cell, Segment + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +VISPY_THEME = { + "light": {"bg": "white", "fg": "black"}, + "dark": {"bg": "black", "fg": "white"}, +} +PYNEUROML_VISPY_THEME = "light" + + +def add_text_to_vispy_3D_plot( + current_scene: SceneCanvas, + xv: typing.List[float], + yv: typing.List[float], + zv: typing.List[float], + color: str, + text: str, +): + """Add text to a vispy plot between two points. + + Wrapper around vispy.scene.visuals.Text + + Rotates the text label to ensure it is at the same angle as the line. + + :param scene: vispy scene object + :type scene: SceneCanvas + :param xv: start and end coordinates in one axis + :type xv: list[x1, x2] + :param yv: start and end coordinates in second axis + :type yv: list[y1, y2] + :param zv: start and end coordinates in third axix + :type zv: list[z1, z2] + :param color: color of text + :type color: str + :param text: text to write + :type text: str + """ + angle = int(numpy.rad2deg(numpy.arctan2((yv[1] - yv[0]), (xv[1] - xv[0])))) + if angle > 90: + angle -= 180 + elif angle < -90: + angle += 180 + + return scene.Text( + pos=((xv[0] + xv[1]) / 2, (yv[0] + yv[1]) / 2, (zv[0] + zv[1]) / 2), + text=text, + color=color, + rotation=angle, + parent=current_scene, + ) + + +def add_text_to_matplotlib_2D_plot( + ax: matplotlib.axes.Axes, + xv: typing.List[float], + yv: typing.List[float], + color: str, + text: str, + horizontal: str = "center", + vertical: str = "bottom", + clip_on: bool = True, +): + """Add text to a matplotlib plot between two points + + Wrapper around matplotlib.axes.Axes.text. + + Rotates the text label to ensure it is at the same angle as the line. + + :param ax: matplotlib axis object + :type ax: Axes + :param xv: start and end coordinates in one axis + :type xv: list[x1, x2] + :param yv: start and end coordinates in second axix + :type yv: list[y1, y2] + :param color: color of text + :type color: str + :param text: text to write + :type text: str + :param clip_on: toggle clip_on (if False, text will also be shown outside plot) + :type clip_on: bool + + """ + angle = int(numpy.rad2deg(numpy.arctan2((yv[1] - yv[0]), (xv[1] - xv[0])))) + if angle > 90: + angle -= 180 + elif angle < -90: + angle += 180 + + ax.text( + (xv[0] + xv[1]) / 2, + (yv[0] + yv[1]) / 2, + text, + color=color, + horizontalalignment=horizontal, + verticalalignment=vertical, + rotation_mode="default", + rotation=angle, + clip_on=clip_on, + ) + + +def get_next_hex_color(my_random: typing.Optional[random.Random] = None) -> str: + """Get a new randomly generated HEX colour code. + + You may pass a random.Random instance that you may be used. Otherwise the + default Python random generator will be used. + + :param my_random: a random.Random object + :type my_random: random.Random + :returns: HEX colour code + """ + if my_random is not None: + return "#%06x" % my_random.randint(0, 0xFFFFFF) + else: + return "#%06x" % random.randint(0, 0xFFFFFF) + + +def add_box_to_matplotlib_2D_plot(ax, xy, height, width, color): + """Add a box to a matplotlib plot, at xy of `height`, `width` and `color`. + + :param ax: matplotlib.axes.Axes object + :type ax: matplotlob.axes.Axis + :param xy: bottom left corner of box + :type xy: typing.List[float] + :param height: height of box + :type height: float + :param width: width of box + :type width: float + :param color: color of box for edge, face, fill + :type color: str + :returns: None + + """ + ax.add_patch( + Rectangle(xy, width, height, edgecolor=color, facecolor=color, fill=True) + ) + + +def get_new_matplotlib_morph_plot( + title: str = "", plane2d: str = "xy" +) -> typing.Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]: + """Get a new 2D matplotlib plot for morphology related plots. + + :param title: title of plot + :type title: str + :param plane2d: plane to use + :type plane: str + :returns: new [matplotlib.figure.Figure, matplotlib.axes.Axes] + :rtype: [matplotlib.figure.Figure, matplotlib.axes.Axes] + """ + fig, ax = plt.subplots(1, 1) # noqa + plt.get_current_fig_manager().set_window_title(title) + + ax.set_aspect("equal") + + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.yaxis.set_ticks_position("left") + ax.xaxis.set_ticks_position("bottom") + + if plane2d == "xy": + ax.set_xlabel("x (μm)") + ax.set_ylabel("y (μm)") + elif plane2d == "yx": + ax.set_xlabel("y (μm)") + ax.set_ylabel("x (μm)") + elif plane2d == "xz": + ax.set_xlabel("x (μm)") + ax.set_ylabel("z (μm)") + elif plane2d == "zx": + ax.set_xlabel("z (μm)") + ax.set_ylabel("x (μm)") + elif plane2d == "yz": + ax.set_xlabel("y (μm)") + ax.set_ylabel("z (μm)") + elif plane2d == "zy": + ax.set_xlabel("z (μm)") + ax.set_ylabel("y (μm)") + else: + raise ValueError(f"Invalid value for plane: {plane2d}") + + return fig, ax + + +class LineDataUnits(Line2D): + """New Line class for making lines with specific widthS + + Reference: + https://stackoverflow.com/questions/19394505/expand-the-line-with-specified-width-in-data-unit + """ + + def __init__(self, *args, **kwargs): + _lw_data = kwargs.pop("linewidth", 1) + super().__init__(*args, **kwargs) + self._lw_data = _lw_data + + def _get_lw(self): + if self.axes is not None: + ppd = 72.0 / self.axes.figure.dpi + trans = self.axes.transData.transform + return ((trans((1, self._lw_data)) - trans((0, 0))) * ppd)[1] + else: + return 1 + + def _set_lw(self, lw): + self._lw_data = lw + + _linewidth = property(_get_lw, _set_lw) + + +def autoscale_matplotlib_plot(verbose: bool = False, square: bool = True) -> None: + """Autoscale the current matplotlib plot + + :param verbose: toggle verbosity + :type verbose: bool + :param square: toggle squaring of plot + :type square: bool + :returns: None + + """ + plt.autoscale() + xl = plt.xlim() + yl = plt.ylim() + if verbose: + print("Auto limits - x: %s , y: %s" % (xl, yl)) + + small = 0.1 + if xl[1] - xl[0] < small and yl[1] - yl[0] < small: # i.e. only a point + plt.xlim([-100, 100]) + plt.ylim([-100, 100]) + elif xl[1] - xl[0] < small: + d_10 = (yl[1] - yl[0]) / 10 + m = xl[0] + (xl[1] - xl[0]) / 2.0 + plt.xlim([m - d_10, m + d_10]) + elif yl[1] - yl[0] < small: + d_10 = (xl[1] - xl[0]) / 10 + m = yl[0] + (yl[1] - yl[0]) / 2.0 + plt.ylim([m - d_10, m + d_10]) + + if square: + if xl[1] - xl[0] > yl[1] - yl[0]: + d2 = (xl[1] - xl[0]) / 2 + m = yl[0] + (yl[1] - yl[0]) / 2.0 + plt.ylim([m - d2, m + d2]) + + if xl[1] - xl[0] < yl[1] - yl[0]: + d2 = (yl[1] - yl[0]) / 2 + m = xl[0] + (xl[1] - xl[0]) / 2.0 + plt.xlim([m - d2, m + d2]) + + +def add_scalebar_to_matplotlib_plot(axis_min_max, ax): + """Add a scalebar to matplotlib plots. + + The scalebar is of magnitude 50 by default, but if the difference between + max and min vals is less than 100, it's reduced to 5, and if the difference + is less than 10, it's reduced to 1. + + :param axis_min_max: minimum, maximum value in plot + :type axis_min_max: [float, float] + :param ax: axis to plot scalebar at + :type ax: matplotlib.axes.Axes + :returns: None + + """ + # add a scalebar + # ax = fig.add_axes([0, 0, 1, 1]) + sc_val = 50 + if axis_min_max[1] - axis_min_max[0] < 100: + sc_val = 5 + if axis_min_max[1] - axis_min_max[0] < 10: + sc_val = 1 + scalebar1 = ScaleBar( + 0.001, + units="mm", + dimension="si-length", + scale_loc="top", + location="lower right", + fixed_value=sc_val, + fixed_units="um", + box_alpha=0.8, + ) + ax.add_artist(scalebar1) + + +def add_line_to_matplotlib_2D_plot(ax, xv, yv, width, color, axis_min_max): + """Add a line to a matplotlib plot + + :param ax: matplotlib.axes.Axes object + :type ax: matplotlib.axes.Axes + :param xv: x values + :type xv: [float, float] + :param yv: y values + :type yv: [float, float] + :param width: width of line + :type width: float + :param color: color of line + :type color: str + :param axis_min_max: min, max value of axis + :type axis_min_max: [float, float]""" + + if ( + abs(xv[0] - xv[1]) < 0.01 and abs(yv[0] - yv[1]) < 0.01 + ): # looking at the cylinder from the top, OR a sphere, so draw a circle + xv[1] = xv[1] + width / 1000.0 + yv[1] = yv[1] + width / 1000.0 + + ax.add_line( + LineDataUnits(xv, yv, linewidth=width, solid_capstyle="round", color=color) + ) + + ax.add_line( + LineDataUnits(xv, yv, linewidth=width, solid_capstyle="butt", color=color) + ) + + axis_min_max[0] = min(axis_min_max[0], xv[0]) + axis_min_max[0] = min(axis_min_max[0], xv[1]) + axis_min_max[1] = max(axis_min_max[1], xv[0]) + axis_min_max[1] = max(axis_min_max[1], xv[1]) + + +def create_new_vispy_canvas( + view_min: typing.Optional[typing.List[float]] = None, + view_max: typing.Optional[typing.List[float]] = None, + title: str = "", + console_font_size: float = 10, + axes_pos: typing.Optional[typing.List] = None, + axes_length: float = 100, + axes_width: int = 2, + theme=PYNEUROML_VISPY_THEME, +): + """Create a new vispy scene canvas with a view and optional axes lines + + Reference: https://vispy.org/gallery/scene/axes_plot.html + + :param view_min: min view co-ordinates + :type view_min: [float, float, float] + :param view_max: max view co-ordinates + :type view_max: [float, float, float] + :param title: title of plot + :type title: str + :param axes_pos: position to draw axes at + :type axes_pos: [float, float, float] + :param axes_length: length of axes + :type axes_length: float + :param axes_width: width of axes lines + :type axes_width: float + :returns: scene, view + """ + canvas = scene.SceneCanvas( + keys="interactive", + show=True, + bgcolor=VISPY_THEME[theme]["bg"], + size=(800, 600), + title="NeuroML viewer (VisPy)", + ) + grid = canvas.central_widget.add_grid(margin=10) + grid.spacing = 0 + + title_widget = scene.Label(title, color=VISPY_THEME[theme]["fg"]) + title_widget.height_max = 80 + grid.add_widget(title_widget, row=0, col=0, col_span=1) + + console_widget = scene.Console( + text_color=VISPY_THEME[theme]["fg"], + font_size=console_font_size, + ) + console_widget.height_max = 80 + grid.add_widget(console_widget, row=2, col=0, col_span=1) + + bottom_padding = grid.add_widget(row=3, col=0, col_span=1) + bottom_padding.height_max = 10 + + view = grid.add_view(row=1, col=0, border_color=None) + + # create cameras + # https://vispy.org/gallery/scene/flipped_axis.html + cam1 = scene.cameras.PanZoomCamera(parent=view.scene, name="PanZoom") + + cam2 = scene.cameras.TurntableCamera(parent=view.scene, name="Turntable") + + cam3 = scene.cameras.ArcballCamera(parent=view.scene, name="Arcball") + + cam4 = scene.cameras.FlyCamera(parent=view.scene, name="Fly") + # do not keep z up + cam4.autoroll = False + + cams = [cam4, cam2] + + # console text + console_text = "Controls: reset view: 0; quit: Esc/9" + if len(cams) > 1: + console_text += "; cycle camera: 1, 2 (fwd/bwd)" + + cam_text = { + cam1: textwrap.dedent( + """ + Left mouse button: pans view; right mouse button or scroll: + zooms""" + ), + cam2: textwrap.dedent( + """ + Left mouse button: orbits view around center point; right mouse + button or scroll: change zoom level; Shift + left mouse button: + translate center point; Shift + right mouse button: change field of + view; r/R: view rotation animation""" + ), + cam3: textwrap.dedent( + """ + Left mouse button: orbits view around center point; right + mouse button or scroll: change zoom level; Shift + left mouse + button: translate center point; Shift + right mouse button: change + field of view""" + ), + cam4: textwrap.dedent( + """ + Arrow keys/WASD to move forward/backwards/left/right; F/C to move + up and down; Space to brake; Left mouse button/I/K/J/L to control + pitch and yaw; Q/E for rolling""" + ), + } + + # Turntable is default + cam_index = 1 + view.camera = cams[cam_index] + + if view_min is not None and view_max is not None: + view_center = (numpy.array(view_max) + numpy.array(view_min)) / 2 + logger.debug(f"Center is {view_center}") + cam1.center = [view_center[0], view_center[1]] + cam2.center = view_center + cam3.center = view_center + cam4.center = view_center + + for acam in cams: + x_width = abs(view_min[0] - view_max[0]) + y_width = abs(view_min[1] - view_max[1]) + z_width = abs(view_min[2] - view_max[2]) + + xrange = ( + (view_min[0] - x_width * 0.02, view_max[0] + x_width * 0.02) + if x_width > 0 + else (-100, 100) + ) + yrange = ( + (view_min[1] - y_width * 0.02, view_max[1] + y_width * 0.02) + if y_width > 0 + else (-100, 100) + ) + zrange = ( + (view_min[2] - z_width * 0.02, view_max[2] + z_width * 0.02) + if z_width > 0 + else (-100, 100) + ) + logger.debug(f"{xrange}, {yrange}, {zrange}") + + acam.set_range(x=xrange, y=yrange, z=zrange) + + for acam in cams: + acam.set_default_state() + + console_widget.write(f"Center: {view.camera.center}") + console_widget.write(console_text) + console_widget.write( + f"Current camera: {view.camera.name}: " + + cam_text[view.camera].replace("\n", " ").strip() + ) + + if axes_pos: + points = [ + axes_pos, # origin + [axes_pos[0] + axes_length, axes_pos[1], axes_pos[2]], + [axes_pos[0], axes_pos[1] + axes_length, axes_pos[2]], + [axes_pos[0], axes_pos[1], axes_pos[2] + axes_length], + ] + scene.Line( + points, + connect=numpy.array([[0, 1], [0, 2], [0, 3]]), + parent=view.scene, + color=VISPY_THEME[theme]["fg"], + width=axes_width, + ) + + def vispy_rotate(self): + view.camera.orbit(azim=1, elev=0) + + rotation_timer = Timer(connect=vispy_rotate) + + @canvas.events.key_press.connect + def vispy_on_key_press(event): + nonlocal cam_index + + # Disable camera cycling. The fly camera looks sufficient. + # Keeping views/ranges same when switching cameras is not simple. + # Prev + if event.text == "1": + cam_index = (cam_index - 1) % len(cams) + view.camera = cams[cam_index] + # next + elif event.text == "2": + cam_index = (cam_index + 1) % len(cams) + view.camera = cams[cam_index] + # for turntable only: rotate animation + elif event.text == "R" or event.text == "r": + if view.camera == cam2: + if rotation_timer.running: + rotation_timer.stop() + else: + rotation_timer.start() + # reset + elif event.text == "0": + view.camera.reset() + # quit + elif event.text == "9": + canvas.app.quit() + + console_widget.clear() + # console_widget.write(f"Center: {view.camera.center}") + console_widget.write(console_text) + console_widget.write( + f"Current camera: {view.camera.name}: " + + cam_text[view.camera].replace("\n", " ").strip() + ) + + return scene, view + + +def get_cell_bound_box(cell: Cell): + """Get a boundary box for a cell + + :param cell: cell to get boundary box for + :type cell: neuroml.Cell + :returns: tuple (min view, max view) + + """ + seg0 = cell.morphology.segments[0] # type: Segment + ex1 = numpy.array([seg0.distal.x, seg0.distal.y, seg0.distal.z]) + seg1 = cell.morphology.segments[-1] # type: Segment + ex2 = numpy.array([seg1.distal.x, seg1.distal.y, seg1.distal.z]) + center = (ex1 + ex2) / 2 + diff = numpy.linalg.norm(ex2 - ex1) + view_min = center - diff + view_max = center + diff + + return view_min, view_max diff --git a/requirements-development.txt b/requirements-development.txt index 40308a38..2d3b4d3c 100644 --- a/requirements-development.txt +++ b/requirements-development.txt @@ -2,6 +2,8 @@ argparse airspeed>=0.5.5 matplotlib graphviz +vispy +pyqt5 NEURON ppft diff --git a/requirements-experimental.txt b/requirements-experimental.txt index 40308a38..2d3b4d3c 100644 --- a/requirements-experimental.txt +++ b/requirements-experimental.txt @@ -2,6 +2,8 @@ argparse airspeed>=0.5.5 matplotlib graphviz +vispy +pyqt5 NEURON ppft diff --git a/requirements.txt b/requirements.txt index c988c09b..d40a25a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,8 @@ argparse airspeed>=0.5.5 matplotlib graphviz +vispy +pyqt5 modelspec>=0.1.3 NEURON ppft diff --git a/setup.py b/setup.py index bfd089b0..7b665516 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ "tune": ["neurotune @ git+https://github.com/NeuralEnsemble/neurotune.git@master#egg=neurotune", "inspyred @ git+https://github.com/aarongarrett/inspyred.git@master#egg=inspyred", "ppft"], + "vispy": ["vispy", "pyqt5"], } extras["all"] = sum(extras.values(), []), diff --git a/test-ghactions.sh b/test-ghactions.sh index 0fd0c846..f3d79ee1 100755 --- a/test-ghactions.sh +++ b/test-ghactions.sh @@ -135,6 +135,7 @@ if [ "$run_neuron_examples" == true ]; then echo "################################################" echo "## Try exporting morphologies to NeuroML from NEURON" + nrnivmodl # Export NeuroML v1 from NEURON example python export_neuroml1.py diff --git a/tests/plot/Izh2007Cells.net.nml b/tests/plot/Izh2007Cells.net.nml new file mode 100644 index 00000000..87cb8f17 --- /dev/null +++ b/tests/plot/Izh2007Cells.net.nml @@ -0,0 +1,97 @@ + + + + + + + Regular spiking cell + + + + + Weakly adapting cell + + + + + + Strongly adapting cell + + + + + Low threshold spiking cell + + + + + + + + + + + + + A number of different Izhikevich spiking cells + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/plot/test-spherical-soma.cell.nml b/tests/plot/test-spherical-soma.cell.nml new file mode 100644 index 00000000..697ca7e9 --- /dev/null +++ b/tests/plot/test-spherical-soma.cell.nml @@ -0,0 +1,115 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Default soma segment group for the cell + + + + + + + Default axon segment group for the cell + + + + + + + Default dendrite segment group for the cell + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/plot/test_morphology_plot.py b/tests/plot/test_morphology_plot.py index 6fe196ec..c7ba0330 100644 --- a/tests/plot/test_morphology_plot.py +++ b/tests/plot/test_morphology_plot.py @@ -11,7 +11,21 @@ import logging import pathlib as pl -from pyneuroml.plot.PlotMorphology import plot_2D, plot_interactive_3D +import pytest +import numpy +import neuroml +from pyneuroml.plot.PlotMorphology import ( + plot_2D, + plot_2D_cell_morphology, + plot_3D_cell_morphology_plotly, + plot_2D_schematic, + plot_segment_groups_curtain_plots, + plot_2D_point_cells, + plot_3D_schematic, + plot_3D_cell_morphology, + plot_interactive_3D, +) +from pyneuroml.pynml import read_neuroml2_file from .. import BaseTestCase logger = logging.getLogger(__name__) @@ -22,6 +36,31 @@ class TestMorphologyPlot(BaseTestCase): """Test Plot module""" + def test_2d_point_plotter(self): + """Test plot_2D_point_cells function.""" + nml_files = ["tests/plot/Izh2007Cells.net.nml"] + for nml_file in nml_files: + ofile = pl.Path(nml_file).name + for plane in ["xy", "yz", "xz"]: + filename = f"test_morphology_plot_2d_point_{ofile.replace('.', '_', 100)}_{plane}.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D(nml_file, nogui=True, plane2d=plane, save_to_file=filename) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + + @pytest.mark.localonly + def test_3d_point_plotter(self): + """Test plot_2D_point_cells function.""" + nml_files = ["tests/plot/Izh2007Cells.net.nml"] + for nml_file in nml_files: + plot_interactive_3D(nml_file, theme="dark", nogui=True) + def test_2d_plotter(self): """Test plot_2D function.""" nml_files = ["tests/plot/Cell_497232312.cell.nml", "tests/plot/test.cell.nml"] @@ -40,12 +79,47 @@ def test_2d_plotter(self): self.assertIsFile(filename) pl.Path(filename).unlink() + def test_2d_morphology_plotter_data_overlay(self): + """Test plot_2D_cell_morphology method with data.""" + nml_files = ["tests/plot/Cell_497232312.cell.nml"] + for nml_file in nml_files: + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + ofile = pl.Path(nml_file).name + plane = "xy" + filename = f"test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}_with_data.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + segs = cell.get_all_segments_in_group("all") + values = list(numpy.random.randint(50, 101, 1800)) + list( + numpy.random.randint(0, 51, len(segs) - 1800) + ) + data_dict = dict(zip(segs, values)) + + plot_2D_cell_morphology( + cell=cell, + nogui=True, + plane2d=plane, + save_to_file=filename, + overlay_data=data_dict, + overlay_data_label="Test", + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + def test_2d_plotter_network(self): """Test plot_2D function with a network of a few cells.""" nml_file = "tests/plot/L23-example/TestNetwork.net.nml" ofile = pl.Path(nml_file).name for plane in ["xy", "yz", "xz"]: - filename = f"test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" + filename = ( + f"test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" + ) # remove the file first try: pl.Path(filename).unlink() @@ -57,8 +131,90 @@ def test_2d_plotter_network(self): self.assertIsFile(filename) pl.Path(filename).unlink() - def test_3d_plotter(self): - """Test plot_interactive_3D function.""" + def test_2d_constant_plotter_network(self): + """Test plot_2D_schematic function with a network of a few cells.""" + nml_file = "tests/plot/L23-example/TestNetwork.net.nml" + ofile = pl.Path(nml_file).name + for plane in ["xy", "yz", "xz"]: + filename = f"test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}_constant.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D( + nml_file, + nogui=True, + plane2d=plane, + save_to_file=filename, + plot_type="Constant", + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + + def test_2d_schematic_plotter_network(self): + """Test plot_2D_schematic function with a network of a few cells.""" + nml_file = "tests/plot/L23-example/TestNetwork.net.nml" + ofile = pl.Path(nml_file).name + for plane in ["xy", "yz", "xz"]: + filename = f"test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}_schematic.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D( + nml_file, + nogui=True, + plane2d=plane, + save_to_file=filename, + plot_type="Schematic", + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + + @pytest.mark.localonly + def test_3d_schematic_plotter(self): + """Test plot_3D_schematic plotter function.""" + nml_file = "tests/plot/L23-example/HL23PYR.cell.nml" + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + plot_3D_schematic( + cell, + segment_groups=None, + nogui=True, + ) + + @pytest.mark.localonly + def test_3d_morphology_plotter_vispy_network(self): + """Test plot_3D_cell_morphology_vispy function.""" + nml_file = "tests/plot/L23-example/TestNetwork.net.nml" + plot_interactive_3D(nml_file, min_width=1, nogui=True, theme="dark") + + @pytest.mark.localonly + def test_3d_plotter_vispy(self): + """Test plot_3D_cell_morphology_vispy function.""" + nml_file = "tests/plot/L23-example/HL23PYR.cell.nml" + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + plot_3D_cell_morphology( + cell=cell, nogui=True, color="Groups", verbose=True, plot_type="Constant" + ) + + # test a circular soma + nml_file = "tests/plot/test-spherical-soma.cell.nml" + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + plot_3D_cell_morphology( + cell=cell, nogui=True, color="Groups", verbose=True, plot_type="Constant" + ) + + def test_3d_plotter_plotly(self): + """Test plot_3D_cell_morphology_plotly function.""" nml_files = ["tests/plot/Cell_497232312.cell.nml", "tests/plot/test.cell.nml"] for nml_file in nml_files: ofile = pl.Path(nml_file).name @@ -69,7 +225,131 @@ def test_3d_plotter(self): except FileNotFoundError: pass - plot_interactive_3D(nml_file, nogui=True, save_to_file=filename) + plot_3D_cell_morphology_plotly(nml_file, nogui=True, save_to_file=filename) self.assertIsFile(filename) pl.Path(filename).unlink() + + def test_2d_schematic_plotter(self): + """Test plot_2D_schematic function.""" + nml_file = "tests/plot/Cell_497232312.cell.nml" + olm_file = "tests/plot/test.cell.nml" + + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + ofile = pl.Path(nml_file).name + + olm_doc = read_neuroml2_file(olm_file) + olm_cell = olm_doc.cells[0] # type: neuroml.Cell + olm_ofile = pl.Path(olm_file).name + + for plane in ["xy", "yz", "xz"]: + # olm cell + filename = ( + f"test_schematic_plot_2d_{olm_ofile.replace('.', '_', 100)}_{plane}.png" + ) + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D_schematic( + olm_cell, + segment_groups=["soma_0", "dendrite_0", "axon_0"], + nogui=True, + plane2d=plane, + save_to_file=filename, + ) + + # more complex cell + filename = ( + f"test_schematic_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" + ) + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D_schematic( + cell, + segment_groups=None, + nogui=True, + plane2d=plane, + save_to_file=filename, + labels=True, + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + + def test_plot_segment_groups_curtain_plots(self): + """Test plot_segment_groups_curtain_plots function.""" + nml_file = "tests/plot/Cell_497232312.cell.nml" + + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + ofile = pl.Path(nml_file).name + + # more complex cell + filename = f"test_curtain_plot_2d_{ofile.replace('.', '_', 100)}.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + sgs = cell.get_segment_groups_by_substring("apic_") + # sgs_1 = cell.get_segment_groups_by_substring("dend_") + sgs_ids = list(sgs.keys()) # + list(sgs_1.keys()) + plot_segment_groups_curtain_plots( + cell, + segment_groups=sgs_ids[0:50], + nogui=True, + save_to_file=filename, + labels=True, + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + + def test_plot_segment_groups_curtain_plots_with_data(self): + """Test plot_segment_groups_curtain_plots function with data overlay.""" + nml_file = "tests/plot/Cell_497232312.cell.nml" + + nml_doc = read_neuroml2_file(nml_file) + cell = nml_doc.cells[0] # type: neuroml.Cell + ofile = pl.Path(nml_file).name + + # more complex cell + filename = f"test_curtain_plot_2d_{ofile.replace('.', '_', 100)}_withdata.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + sgs = cell.get_segment_groups_by_substring("apic_") + sgs_1 = cell.get_segment_groups_by_substring("dend_") + sgs_ids = list(sgs.keys()) + list(sgs_1.keys()) + data_dict = {} + + nsgs = 50 + + for sg_id in sgs_ids[0:nsgs]: + lensgs = len(cell.get_all_segments_in_group(sg_id)) + data_dict[sg_id] = numpy.random.randint(0, 101, lensgs) + + plot_segment_groups_curtain_plots( + cell, + segment_groups=sgs_ids[0:nsgs], + nogui=True, + save_to_file=filename, + labels=True, + overlay_data=data_dict, + overlay_data_label="Random values (0, 100)", + width=4, + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink()