-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Matplotlib upgrade #270
base: main
Are you sure you want to change the base?
Matplotlib upgrade #270
Conversation
* Initial prototype * feat: Add environment tests * fix: Update esquilax version to fix type issues * docs: Add docstrings * docs: Add docstrings * test: Test multiple reward types * test: Add smoke tests and add max-steps check * feat: Implement pred-prey environment viewer * refactor: Pull out common viewer functionality * test: Add reward and view tests * test: Add rendering tests and add test docstrings * docs: Add predator-prey environment documentation page * docs: Cleanup docstrings * docs: Cleanup docstrings
* refactor: Formatting fixes * fix: Implement rewards as class * refactor: Implement observation as NamedTuple * refactor: Implement initial state generator * docs: Update docstrings * refactor: Add env animate method * docs: Link env into API docs
* feat: Prototype search and rescue environment * test: Add additional tests * docs: Update docs * refactor: Update target plot color based on status * refactor: Formatting and fix remaining typos.
* refactor: Rename to targets_remaining * docs: Formatting and expand docs * refactor: Move target and reward checks into utils module * fix: Set agent and target numbers via generator * refactor: Terminate episode if all targets found * test: Add swarms.common tests * refactor: Move agent initialisation into generator * test: Add environment utility tests
* refactor: Set -1.0 as default view value * refactor: Restructure tests * refactor: Pull out common functionality and fix formatting * refactor: Better function names
* Set plot range in viewer only * Detect targets in a single pass
* Prototype search-and-rescue network and training * Refactor and docstrings * Share rewards if found by multiple agents * Use Distrax for normal distribution * Add bijector for continuous action space * Reshape returned from actor-critic * Prototype tanh bijector w clipping * Fix random agent and cleanup * Customisable reward aggregation * Cleanup * Configurable vision model * Docstrings and cleanup params
* Add observation including all targets * Consistent test module names * Use CNN embedding
* Use channels view parameters * Rename parameters * Include step-number in observation * Add velocity field to targets * Add time scaled reward function
* Update docstrings * Update tests * Update environment readme
Wow this is an awesome addition, thanks @zombie-einstein! Which environments are still remaining? Also I think the way you've done it with |
Apart from rubix cube they all need a look at, with varying degrees of complexity. I think the general pattern needs to be something like a Some also currently have actors that are created dynamically during plotting I've been trying to work out the best pattern for, in some cases it seems it would be more efficient to initialise all the actors ahead of time, and update their states as needed. E.g. for the sudoko environment initialise all the text actors and update the text during animation (rather than create them on the fly). So there's a decent chunk of work to refactor all the plots. One option could be merge something like this PR, which will updated the requirement and passes linting/tests, and still create plots (just less efficiently), and then incrementally update the individual viewers in separate PRs to make best use of the new API. |
Hi @zombie-einstein sorry for leaving this hanging, I needed a bit of time off 😂 I'll have a deeper look at it this week and give some a proper opinion 😄 |
…tlib_upgrade # Conflicts: # requirements/requirements.txt
Hey @sash-a I fixed the graph-colouring, MMST, sodoku, and TSP viewers (still need to work out the img path for the knapsack viewer). Is the animate function intended to display states over multiple episodes? For example for the MMST env, the structure of the graph is resampled when the environment is reset. It's nice if the edges were static, then you don't need to recalculate the node positions etc and can just update the edge & node properties, and also save on the relatively node position calculation. But then this causes an issue if the states span multiple episodes. There's maybe a way to do this where you check if the episode has terminated, though a bit more involved. |
So it currently does work over multiple episodes and I think we should try keep that functionality. Maybe there is a way to tell if an episode is done from the state? |
For knapsack what should work is if you follow the method used in TSP for loading images. I think all that should be needed to move the file into the jumanji repo instead of reading it from docs as the We recently fixed TSP and I assume missed knapsack. |
Maybe comparing what has changes between two steps? Though would be very environment specific. Some more generic solutions could be
Cheers that's an easy enough fix then |
I would go with the lower impact option as the other option is a big API change. I don't think it's a big issue if the rendering is a bit slower than it could be 🤷 |
Hey @sash-a, I think this is all working now, though some of the code definitely needs a work over. Turned out to be a bit more involved:
I also grabbed a couple other changes whilst looking at this:
Wanted to check:
|
Hi @zombie-einstein thanks for this I should have time to take a look today |
Hi @zombie-einstein sorry for not having a look at this, the last two weeks have been a bit crazy
I think that would be nice, but also happy to do this in a subsequent PR?
The main reason I'd like to do this is so that we don't pin to an old matplotlib version anymore so I do think we should update it in all cases, but if it's a bit slower than the old functionality I wouldn't worry because the viewers don't need to be too fast |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me. Just a few very minor changes. Can you please explain why there are so many changes for the graph envs? I get there needs to be changes on reset, but not sure why things changed outside of the if not jnp.array_equal(...)
block? (Not saying that it's wrong just curious why)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you rename the file to graph_view_utils
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is in examples can you rather call this visualize_random_agent.py
NODE_SIZE = 0.01 | ||
ROUTE_NODE_SIZE = 100 | ||
DEPOT_SIZE = 0.04 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did this need to change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So (somewhat annoyingly) if you draw nodes/points with a circle-artist the scale (if I remember rightly) is a fraction of the region (hence the fractional values), but using a scatter plot the larger integer value gives sensible node sizes.
This could change though, I think it's generally sensible to draw all the points in one go using a scatter, rather than individual artists.
while routes: | ||
quiver, route_nodes = routes.pop() | ||
quiver.remove() | ||
route_nodes.remove() | ||
updated.append(quiver) | ||
updated.append(route_nodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please explain in a comment what's happening here I'm not quite following
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah will do, I may add a longer comment on the interface outlining how some of the animation functionality is implemented.
In this case the members of the routes, and number of nodes are not static, and so the artists representing them need to be fully replaced, by removing the current ones. You can't pass state (I think) as part of the argument to the update function and so need to keep references to the existing artists in the scope outside the update function. So in the above the steps are:
- Pop the current artists from the outer list
- Call the remove method to remove them from the plot
- Add them to the list of updated artists returned from the update function
- Create new artists and push them in the outer list for use in later updates
I've used a list to act as mutable storage for artists mainly for convenience, maybe a global or class variable would look nicer. Also quite a few lines of code here, was looking to review if there is a neater way to implement.
Sure, so a couple things for graphs in general:
Compared to other environments the graphs are a pain due to all the nodes/edges moving episode, and the potential variable number of artists each step. In comparison for something like Suduko the grid is static, and you know there is a fixed grid of numbers that can just be updated in the animation, you don't need to mess about adding and removing artists inside the update function! |
Cool, I might just grab it, as where implementations vary it kind of makes it hared to refactor when trying to work out what different implementations are doing.
Everything is functional now with the updated version (the graphs were really the ones that didn't play nicely with the new API) updates would be mostly about efficient use of new API, e.g. updating artists rather then continually redrawing. My main concern was a the size of the PR, but I can chip away at the other envs if you don't mind it. |
Hey @sash-a, thought I'd grab a look at #253.
Changes here update Matplotlib version and then makes changes to pass type checking, so it builds and passes tests.
But ...
The new API now requires that the animation update function returns a list of artists that have been updated, these are objects returned from plotting calls like a plot object, or text box (see the table here). The state/data of these objects can be updated without redrawing the whole thing.
I've done this rubiks cube environment:
Doing this for all environments is easier for some than others given how they've been written, it means passing all the relevant artists to the animate function when the plot is created, and then correctly updating them inside the update function.
I think the animations will still work as is, as they still generally update the axes by reference inside the animation function, it's just less efficient then the new API that just updates the relevant artists.