Skip to content
Merged
48 changes: 46 additions & 2 deletions pyneuroml/plot/Plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import typing
import matplotlib
import matplotlib.axes
import matplotlib.animation as animation
from typing import Optional

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,7 +43,9 @@ def generate_plot(
cols_in_legend_box: int = 3,
legend_position: typing.Optional[str] = "best",
show_plot_already: bool = True,
animate: bool = False,
save_figure_to: typing.Optional[str] = None,
save_animation_to: typing.Optional[str] = None,
title_above_plot: bool = False,
verbose: bool = False,
close_plot: bool = False,
Expand Down Expand Up @@ -124,8 +127,12 @@ def generate_plot(
:type legend_position: str
:param show_plot_already: if plot should be shown when created (default: True)
:type show_plot_already: boolean
:param animate: if shown plot should be animated. show_splot_already should be True (default: False)
:type animate: boolean
:param save_figure_to: location to save generated figure to (default: None)
:type save_figure_to: str
:param save_animation_to: location to save generated animation to (default: None)
:type save_animation_to: str
:param title_above_plot: enable/disable title above the plot (default: False)
:type title_above_plot: boolean
:param verbose: enable/disable verbose logging (default: False)
Expand Down Expand Up @@ -174,6 +181,7 @@ def generate_plot(
if not show_yticklabels:
ax.set_yticklabels([])

artists = []
for i in range(len(xvalues)):
linestyle = rcParams["lines.linestyle"] if not linestyles else linestyles[i]
label = "" if not labels else labels[i]
Expand All @@ -182,7 +190,7 @@ def generate_plot(
markersize = rcParams["lines.markersize"] if not markersizes else markersizes[i]

if colors:
plt.plot(
(artist,) = plt.plot(
xvalues[i],
yvalues[i],
marker=marker,
Expand All @@ -193,7 +201,7 @@ def generate_plot(
label=label,
)
else:
plt.plot(
(artist,) = plt.plot(
xvalues[i],
yvalues[i],
marker=marker,
Expand All @@ -202,6 +210,7 @@ def generate_plot(
linewidth=linewidth,
label=label,
)
artists.append(artist)

if labels:
if legend_position == "outer right":
Expand Down Expand Up @@ -240,6 +249,41 @@ def generate_plot(
logger.info("Saved image to %s of plot: %s" % (save_figure_to, title))

if show_plot_already:
if animate:
d = 5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to test this out more. I thought the idea was that if I set duration to 5000ms, the animation would be 5000ms, but I think you're saying that this is not the case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that is what is happening with my current implementation.
My calculations could be improved I think and it should be possible to atleast get as close as possible to set duration, let me try again.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry about getting it exactly right---close enough will do and it can be tweaked later

duration = d * 1000 # in ms
size = len(xvalues[0])
interval = 50 # Delay between frames in milliseconds
pockets = duration // interval
skip = max(size // pockets, 1)
logger.info(
"Animation hyperparameters : duration=%s, size=%s, interval=%s, pockets=%s, skip=%s" % (duration, size, interval, pockets, skip))

def update(frame):
for i, artist in enumerate(artists):
artist.set_xdata(xvalues[i][:frame*skip])
artist.set_ydata(yvalues[i][:frame*skip])
return artists

ani = animation.FuncAnimation(
fig=fig,
frames=size-1,
func=update,
interval=interval,
blit=True,
cache_frame_data=False
)

frame_length_threshold = 5000
if len(xvalues[0]) < frame_length_threshold and save_animation_to:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to limit animations to 5000 frames only? If this depends on the users' setups (how good their hardware etc. is), let's not add this check like this. Maybe we can just add a warning that is thrown if there are more than 5000 points that says "too many data points, this may take a while"? That way the user can still go ahead and use it if they wish. Maybe they'll start it and go eat lunch, and it'll be complete by the time they return (we do this a lot with our simulations :))

Copy link
Contributor Author

@YBCS YBCS May 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I will add a warning and it will ask for user input to continue with the saving or not (possibly showing approx time (minutes) it might take to complete the entire process. )
It will then save or discard based on the user input.
How does that sound ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd skip the user input entirely. Just make it an argument and if it's more than something, just print a warning saying "this could take a while". If it takes too long, the user can always interrupt it and use a smaller value.

(Adding interaction to take user input will make it trickier to do batch analyses, so we want to avoid that as much as possible)

logger.info("Saving animation to %s" %
(save_animation_to))
ani.save(
filename=save_animation_to,
writer="pillow",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should make this an argument too, to allow the user to choose different writers? ffmpeg/imagemagick may be quicker than pillow and support more formats:

https://matplotlib.org/stable/users/explain/animations/animations.html

How about making it an argument of the form:

writer={ 'writer name': ["extra args list"] }

with default value:

writer = { "pillow": [] }

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add this link to the docstring around these arguments:

https://matplotlib.org/stable/users/explain/animations/animations.html#saving-animations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should I include this in the same PR or should it be done in another PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can roll this in here too?

progress_callback=lambda i, n: print(
f'Saving frame {i}/{n}')
)
plt.show()

if close_plot:
Expand Down
34 changes: 33 additions & 1 deletion tests/plot/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
Copyright 2023 NeuroML contributors
"""

import random
import pytest
import unittest
import logging
import pathlib as pl
Expand All @@ -19,9 +21,39 @@


class TestPlot(BaseTestCase):

"""Test Plot module"""

@pytest.mark.localonly
def test_generate_plot_animated(self):
"""Test generate_plot function."""
filename = "tests/plot/test_generate_plot.gif"

# remove the file first
try:
pl.Path(filename).unlink()
except FileNotFoundError:
pass

numpoints = 100
xs = list(range(0, numpoints))
ys = random.choices(list(range(0, 1000)), k=numpoints)
ys2 = random.choices(list(range(0, 1500)), k=numpoints)

generate_plot(
[xs, xs],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[xs1, xs2] is it possible that this list contains items with differing lengths. can xs1 and xs2 have different lengths ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think they can, but [y1, y2] must have the same lengths as [x1, x2] so that matplotlib has complete (x,y) co-ordinates for each point.

For animation, one will have to find the values with the longest length and use that to calculate duration and all that then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"For animation, one will have to find the values with the longest length and use that to calculate duration and all that then?"

Yeah that will be correct as, [plt1, plt2] are individual plots and we only care of the total duration which means the max of the lengths

[ys, ys2],
"Test plot",
xaxis="x",
yaxis="y",
grid=False,
show_plot_already=True,
animate=True,
legend_position="right",
save_animation_to=filename
)
self.assertIsFile(filename)
pl.Path(filename).unlink()

def test_generate_plot(self):
"""Test generate_plot function."""
filename = "tests/plot/test_generate_plot.png"
Expand Down