Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matplotlib upgrade #270

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
3597d9e
feat: Implement predator prey env (#1)
zombie-einstein Nov 4, 2024
c955320
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 4, 2024
6b34657
Merge branch 'main' into main
sash-a Nov 4, 2024
988339b
fix: PR fixes (#2)
zombie-einstein Nov 5, 2024
a0fe7a5
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 5, 2024
b4cce01
style: Run updated pre-commit
zombie-einstein Nov 6, 2024
cb6d88d
refactor: Consolidate predator prey type
zombie-einstein Nov 7, 2024
06de3a0
feat: Implement search and rescue (#3)
zombie-einstein Nov 11, 2024
34beab6
fix: PR fixes (#4)
zombie-einstein Nov 14, 2024
f5fa659
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 15, 2024
072db18
refactor: PR fixes (#5)
zombie-einstein Nov 19, 2024
162a74d
feat: Allow variable environment dimensions (#6)
zombie-einstein Nov 19, 2024
4996869
Merge branch 'main' into main
zombie-einstein Nov 22, 2024
6322f61
fix: Locate targets in single pass (#8)
zombie-einstein Nov 23, 2024
4ba7688
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 28, 2024
9a654b9
feat: training and customisable observations (#7)
zombie-einstein Dec 7, 2024
5021e20
feat: view all targets (#9)
zombie-einstein Dec 9, 2024
c5c7b85
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
13ffb84
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
9e8ac5c
feat: Scaled rewards and target velocities (#10)
zombie-einstein Dec 11, 2024
5c509c7
Pass shape information to timesteps (#11)
zombie-einstein Dec 11, 2024
8acf242
test: extend tests and docs (#12)
zombie-einstein Dec 11, 2024
1792aa6
fix: unpin jax requirement
zombie-einstein Dec 12, 2024
1e66e78
Include agent positions in observation (#13)
zombie-einstein Dec 12, 2024
296e98a
Update animation functions
zombie-einstein Dec 12, 2024
fe8880a
Update rubiks cube viewer for new API
zombie-einstein Dec 13, 2024
407ff79
Upgrade Esquilax and remove unused random keys (#14)
zombie-einstein Dec 27, 2024
b52fefd
Address PR comments
zombie-einstein Jan 9, 2025
04fe710
docs: Review docstrings and docs (#15)
zombie-einstein Jan 12, 2025
ac3f811
fix: Remove enum annotations
zombie-einstein Jan 12, 2025
943a51b
refactor: address pr comments (#16)
zombie-einstein Jan 17, 2025
6a3fdb1
Parameter tweaks
zombie-einstein Jan 17, 2025
ac8838f
refactor: Observation tweaks (#17)
zombie-einstein Jan 20, 2025
05eeedf
refactor: address pr comments (#18)
zombie-einstein Feb 3, 2025
5353ef7
chore: revert to using set colours
sash-a Feb 4, 2025
bc9e252
fix: minor training bug due to refactor
sash-a Feb 4, 2025
eac2f1f
chore: update default parameters to ones tested in mava
sash-a Feb 4, 2025
79a7aa8
chore: add search and rescue to the readme
sash-a Feb 4, 2025
de8e869
Update graph-coloring viewer
zombie-einstein Feb 4, 2025
243ddde
Update mmst
zombie-einstein Feb 4, 2025
211d578
Merge branch 'main' of github.com:zombie-einstein/jumanji into matplo…
zombie-einstein Feb 4, 2025
cbb159d
Merge branch 'instadeepai:main' into main
zombie-einstein Feb 4, 2025
385935d
Merge branch 'main' of github.com:zombie-einstein/jumanji into matplo…
zombie-einstein Feb 4, 2025
ffb9c6f
Update suduko viewer
zombie-einstein Feb 5, 2025
cadb337
Fix tsp viewer
zombie-einstein Feb 5, 2025
3bcb2a5
Refactor graph coloring for multiple episodes
zombie-einstein Feb 10, 2025
5b6ada7
Fix coloring and mmst for multiple episodes
zombie-einstein Feb 12, 2025
759cc45
Fix knapsack image loading
zombie-einstein Feb 12, 2025
5361a75
Replace pkg_resources usage and add animation test script
zombie-einstein Feb 12, 2025
b5613a1
Fix refresh of cvrp on new episode
zombie-einstein Feb 12, 2025
dbc52bb
Refresh tsp at new episode
zombie-einstein Feb 12, 2025
4c3a139
Refresh multi-crvp at new episode
zombie-einstein Feb 12, 2025
e15e70e
Cleanup and refactor graph layout functionality
zombie-einstein Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions examples/viewer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import jax
import requests
from hydra import compose, initialize

from jumanji.training.setup_train import setup_agent, setup_env

envs = [
"bin_pack",
"cleaner",
"connector",
"cvrp",
"flat_pack",
"game_2048",
"graph_coloring",
"job_shop",
"knapsack",
"maze",
"minesweeper",
"mmst",
"multi_cvrp",
"pac_man",
"robot_warehouse",
"lbf",
"rubiks_cube",
"sliding_tile_puzzle",
"snake",
"sokoban",
"sudoku",
"tetris",
"tsp",
]


def download_file(url: str, file_path: str) -> None:
# Send an HTTP GET request to the URL
response = requests.get(url)
# Check if the request was successful (status code 200)
if response.status_code == 200:
with open(file_path, "wb") as f:
f.write(response.content)
else:
print("Failed to download the file.")


def create_animation(env_name: str, agent: str = "random", num_episodes: int = 2) -> None:
print(f"Animating {env_name}")

os.makedirs("configs", exist_ok=True)
config_url = "https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml"
download_file(config_url, "configs/config.yaml")
env_url = f"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env_name}.yaml"
os.makedirs("configs/env", exist_ok=True)
download_file(env_url, f"configs/env/{env_name}.yaml")
os.makedirs("animations", exist_ok=True)

with initialize(version_base=None, config_path="configs"):
cfg = compose(config_name="config.yaml", overrides=[f"env={env_name}", f"agent={agent}"])

env = setup_env(cfg).unwrapped
agent = setup_agent(cfg, env)
policy = jax.jit(agent.make_policy(stochastic=False))

reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)
states = []
key = jax.random.PRNGKey(cfg.seed)

for _ in range(num_episodes):
key, reset_key = jax.random.split(key)
state, timestep = reset_fn(reset_key)
states.append(state)

while not timestep.last():
key, action_key = jax.random.split(key)
observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
action = policy(observation, action_key)
state, timestep = step_fn(state, action.squeeze(axis=0))
states.append(state)

env.animate(states, 100, f"animations/{env_name}_animation.gif")


if __name__ == "__main__":
cli = argparse.ArgumentParser()
cli.add_argument(
"envs",
nargs="*",
type=str,
default=None,
)

args = cli.parse_args()
arg_envs = args.envs
env_list = envs if len(arg_envs) == 0 else args.envs

for env in env_list:
try:
create_animation(env)
except Exception as e:
print(f"{env} failed", e)
16 changes: 7 additions & 9 deletions jumanji/environments/logic/game_2048/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def animate(
# Set up the figure and axes for the game board.
fig, ax = self.get_fig_ax()
fig.suptitle("2048 Score: 0", size=20)
plt.tight_layout()
plt.close(fig)

# Define a function to animate a single game state.
def make_frame(state: State) -> Tuple[Artist]:
Expand Down Expand Up @@ -131,21 +131,19 @@ def get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:
"""
# Check if a figure with an id "2048" already exists.
exists = plt.fignum_exists(self._name)
fig = plt.figure(
self._name,
figsize=(6.0, 6.0),
facecolor=self.COLORS["bg"],
)
if exists:
# If it exists, get the figure and axes objects.
fig = plt.figure(self._name)
ax = fig.get_axes()[0]
else:
# If it doesn't exist, create a new figure and axes objects.
fig = plt.figure(
self._name,
figsize=(6.0, 6.0),
facecolor=self.COLORS["bg"],
)
plt.tight_layout()
if not plt.isinteractive():
fig.show()
ax = fig.add_subplot()
ax = fig.add_subplot(111)
return fig, ax

def render_tile(self, tile_value: int, ax: plt.Axes, row: int, col: int) -> None:
Expand Down
83 changes: 60 additions & 23 deletions jumanji/environments/logic/graph_coloring/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import pairwise
from typing import List, Optional, Sequence, Tuple

import chex
Expand Down Expand Up @@ -62,15 +63,40 @@ def animate(
plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
ax = fig.add_subplot(111)
plt.close(fig)
circles = self._prepare_figure(ax, states[0])
nodes, labels, edges = self._prepare_figure(ax, states[0])

def make_frame(state: State) -> List[Artist]:
for circle, color in zip(circles, state.colors, strict=False):
def make_frame(state_pair: Tuple[State, State]) -> List[Artist]:
prev_state, state = state_pair

for circle, color in zip(nodes, state.colors, strict=False):
circle.set(color=self._color_mapping[color])
return circles
# Update node and edges if new episode
if not np.array_equal(prev_state.adj_matrix, state.adj_matrix):
pos = self._spring_layout(state.adj_matrix, self.num_nodes)
for circle, label, xy in zip(nodes, labels, pos, strict=False):
circle.set_center(xy)
label.set(x=xy[0], y=xy[1])
n = 0
for i in range(self.num_nodes):
for j in range(i + 1, self.num_nodes):
edges[n].set(
xdata=[pos[i][0], pos[j][0]],
ydata=[pos[i][1], pos[j][1]],
visible=state.adj_matrix[i, j],
)
n += 1

return nodes + edges

else:
return nodes

_animation = animation.FuncAnimation(
fig, make_frame, frames=states[1:], interval=interval, blit=False
fig,
make_frame,
frames=pairwise(states),
interval=interval,
blit=False,
)

if save_path:
Expand All @@ -83,15 +109,17 @@ def _set_params(self, state: State) -> None:
self.node_scale = self._calculate_node_scale(self.num_nodes)
self._color_mapping = self._create_color_mapping(self.num_nodes)

def _prepare_figure(self, ax: plt.Axes, state: State) -> List[Artist]:
ax.set_xlim(-0.5, 0.50)
ax.set_ylim(-0.50, 0.50)
def _prepare_figure(
self, ax: plt.Axes, state: State
) -> Tuple[List[plt.Circle], List[plt.Text], List[plt.Line2D]]:
ax.set_xlim(-1.0, 1.0)
ax.set_ylim(-1.0, 1.0)
ax.set_aspect("equal")
ax.axis("off")
pos = self._spring_layout(state.adj_matrix, self.num_nodes)
self._render_edges(ax, pos, state.adj_matrix, self.num_nodes)
artists = self._render_nodes(ax, pos, state.colors)
return artists
edges = self._render_edges(ax, pos, state.adj_matrix, self.num_nodes)
nodes, labels = self._render_nodes(ax, pos, state.colors)
return nodes, labels, edges

def close(self) -> None:
plt.close(self._name)
Expand Down Expand Up @@ -203,11 +231,12 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]:

def _render_nodes(
self, ax: plt.Axes, pos: List[Tuple[float, float]], colors: chex.Array
) -> List[Artist]:
) -> Tuple[List[plt.Circle], List[plt.Text]]:
# Set the radius of the nodes as a fraction of the scale,
# so nodes appear smaller when there are more of them.
node_radius = 0.05 * 5 / self.node_scale
circles = []
labels = []

for i, (x, y) in enumerate(pos):
c = plt.Circle(
Expand All @@ -219,7 +248,7 @@ def _render_nodes(
)
circles.append(c)
ax.add_artist(c)
ax.text(
label = plt.Text(
x,
y,
str(i),
Expand All @@ -229,25 +258,33 @@ def _render_nodes(
weight="bold",
zorder=200,
)
labels.append(label)
ax.add_artist(label)

return circles
return circles, labels

def _render_edges(
self,
ax: plt.Axes,
pos: List[Tuple[float, float]],
adj_matrix: chex.Array,
num_nodes: int,
) -> None:
) -> List[plt.Line2D]:
edges = []

for i in range(num_nodes):
for j in range(i + 1, num_nodes):
if adj_matrix[i, j]:
ax.plot(
[pos[i][0], pos[j][0]],
[pos[i][1], pos[j][1]],
color=self._color_mapping[-1],
linewidth=0.5,
)
edge = plt.Line2D(
[pos[i][0], pos[j][0]],
[pos[i][1], pos[j][1]],
color=self._color_mapping[-1],
linewidth=0.5,
visible=adj_matrix[i, j],
)
ax.add_artist(edge)
edges.append(edge)

return edges

def _calculate_node_scale(self, num_nodes: int) -> int:
# Set the scale of the graph based on the number of nodes,
Expand All @@ -262,6 +299,6 @@ def _create_color_mapping(
colormap = cm.get_cmap("hsv", num_nodes + 1)
color_mapping = []
for colormap_idx in colormap_indices:
color_mapping.append(colormap(colormap_idx))
color_mapping.append(colormap(float(colormap_idx)))
color_mapping.append((0.0, 0.0, 0.0, 1.0)) # Adding black to the color mapping
return color_mapping
12 changes: 5 additions & 7 deletions jumanji/environments/logic/sudoku/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.animation import FuncAnimation
from matplotlib.text import Text

import jumanji
Expand Down Expand Up @@ -80,14 +80,14 @@ def animate(
states: Sequence[State],
interval: int = 500,
save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
) -> FuncAnimation:
fig, ax = plt.subplots(figsize=(6, 6))
plt.title(f"{self._name}")
texts = self._draw(ax, states[0])

board_shape = states[0].board.shape

def make_frame(state: State) -> List[Artist]:
def make_frame(state: State) -> List[plt.Text]:
updated = []
for i in range(board_shape[0]):
for j in range(board_shape[1]):
Expand All @@ -99,9 +99,7 @@ def make_frame(state: State) -> List[Artist]:

return updated

animation = matplotlib.animation.FuncAnimation(
fig, make_frame, frames=states[1:], interval=interval, blit=False
)
animation = FuncAnimation(fig, make_frame, frames=states[1:], interval=interval, blit=False)

if save_path:
animation.save(save_path)
Expand Down Expand Up @@ -136,7 +134,7 @@ def _draw_figures(self, ax: plt.Axes, state: State) -> List[List[Text]]:
"""Loop over the different cells and draws corresponding shapes in the ax object."""
board = state.board
board_shape = board.shape
artists = list()
artists: List[List[Text]] = list()

for i in range(board_shape[0]):
artists.append([])
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions jumanji/environments/packing/knapsack/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib import resources
from typing import Callable, Optional, Sequence, Tuple

import matplotlib.animation
Expand Down Expand Up @@ -104,7 +105,7 @@ def make_frame(state: State) -> Tuple[Artist]:
self._animation = matplotlib.animation.FuncAnimation(
fig,
make_frame,
frames=len(states),
frames=states,
interval=interval,
)

Expand Down Expand Up @@ -141,8 +142,9 @@ def _prepare_figure(self, ax: plt.Axes) -> None:
ax.set_ylim(0, 1)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
map_img = plt.imread("docs/img/knapsack.png")
ax.imshow(map_img, extent=[0, 1, 0, 1])
img_path = resources.files(jumanji.environments.packing.knapsack) / "img/knapsack.png"
sack_img = plt.imread(img_path)
ax.imshow(sack_img, extent=(0, 1, 0, 1))

def _display_human(self, fig: plt.Figure) -> None:
if plt.isinteractive():
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/packing/tetris/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def render(self, state: State) -> Optional[NDArray]:
"""Render Tetris.

Args:
grid: the grid of the Tetris environment to render.
state: State of the Tetris environment to render.

Returns:
RGB array if the render_mode is RenderMode.RGB_ARRAY.
Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/routing/connector/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, name: str, num_agents: int, render_mode: str = "human") -> No

self.colors = [(1.0, 1.0, 1.0, 1.0)] # Initial color must be white.
for colormap_idx in colormap_indecies:
self.colors.append(colormap(colormap_idx))
self.colors.append(colormap(float(colormap_idx)))

# The animation must be stored in a variable that lives as long as the
# animation should run. Otherwise, the animation will get garbage-collected.
Expand Down
Loading