Add save_video util and deprecate RecordVideo in favor of it #3016

Merged
merged 9 commits into from
Aug 29, 2022
Changes from 4 commits
109 changes: 109 additions & 0 deletions gym/utils/save_video.py
@@ -0,0 +1,109 @@
"""Utility functions to save rendering videos."""
import os
from typing import Callable, Optional

import gym
from gym import logger

try:
    from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError:
    raise gym.error.DependencyNotInstalled(
        "MoviePy is not installed, run `pip install moviepy`"
    )


def capped_cubic_video_schedule(episode_id: int) -> bool:
    """The default episode trigger.

    This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...

    Args:
        episode_id: The episode number

    Returns:
        ``True`` if a recording should be started at this episode
    """
    if episode_id < 1000:
        return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
    else:
        return episode_id % 1000 == 0


def save_video(
    frames: list,
    video_folder: str,
    episode_trigger: Optional[Callable[[int], bool]] = None,
    step_trigger: Optional[Callable[[int], bool]] = None,
    video_length: Optional[int] = None,
    name_prefix: str = "rl-video",
    episode_index: int = 0,
    step_starting_index: int = 0,
    **kwargs,
):
    """Save videos from rendering frames.

    This function extracts videos from the list of rendered frames of an episode.

    Args:
        frames (List[RenderFrame]): A list of frames to compose the video.
        video_folder (str): The folder where the recordings will be stored
        episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode

Contributor

I am confused: why is episode_trigger still here? It seems strange that when users call save_video, the videos might not be saved…

Contributor Author

Initially, the expected type of frames was a List of episodes, where an episode is a List of frames.
After talking with @pseudo-rnd-thoughts, we agreed that it is simpler as it is now, with frames being a List of frames from a single episode.

Given the current design, yes, we can drop episode_trigger and the user can handle it with an if statement.
Note that save_video may still save nothing, depending on step_trigger.
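
The if-statement approach mentioned above could look roughly like the following sketch. It is not part of the PR: `record_selected_episodes`, `should_record`, and `save_fn` are hypothetical names, and `save_fn` is a stand-in for a call such as `save_video(frames, "videos", fps=...)` so that the example runs without gym or moviepy installed.

```python
from typing import Callable, List


def record_selected_episodes(
    episodes: List[list],
    should_record: Callable[[int], bool],
    save_fn: Callable[[list, int], None],
) -> List[int]:
    """Call save_fn only for episodes selected by should_record (hypothetical helper)."""
    saved = []
    for episode_index, frames in enumerate(episodes):
        # The caller decides with a plain `if`; no episode_trigger argument is needed.
        if should_record(episode_index):
            save_fn(frames, episode_index)
            saved.append(episode_index)
    return saved


# Stand-in for save_video: remember how many frames would have been written.
written = {}
saved_episodes = record_selected_episodes(
    episodes=[["f0", "f1"], ["f0"], ["f0", "f1", "f2"], ["f0"]],
    should_record=lambda i: i % 2 == 0,  # assumed policy: every second episode
    save_fn=lambda frames, i: written.update({i: len(frames)}),
)
print(saved_episodes)
```

With this shape, the saving function itself would always write a file when called, which is the behavior the reviewer expected.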

        step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
        video_length (int): The length of recorded episodes. If it isn't specified, the entire episode is recorded.
            Otherwise, snippets of the specified length are captured.
        name_prefix (str): Will be prepended to the filename of the recordings.
        episode_index (int): The index of the current episode.
        step_starting_index (int): The step index of the first frame.
        **kwargs: The kwargs that will be passed to moviepy's ImageSequenceClip.
            You need to specify either fps or duration.

    Example:
        >>> import gym
        >>> from gym.utils.save_video import save_video
        >>> env = gym.make("FrozenLake-v1", render_mode="rgb_array")
        >>> env.reset()
        >>> step_starting_index = 0
        >>> episode_index = 0
        >>> for step_index in range(199):
        ...     action = env.action_space.sample()
        ...     _, _, done, _ = env.step(action)
        ...     if done:
        ...         save_video(
        ...             env.render(),

Contributor

It looks like env.render() is only called here once the env is done. Does this mean that with env = gym.make("FrozenLake-v1", render_mode="rgb_array") every frame is rendered and cached? If so, this could be an undesirable bottleneck, since we might only want to keep the rendered frames of a few episodes.

Contributor Author

Yes, with the current Render API, env.render() returns the list of frames of the whole episode when the mode is rgb_array; the old behavior can be obtained with single_rgb_array.

When you call render() or reset(), the frame list is cleared.

Contributor

Sorry, I am struggling to understand this. If I don’t want to render every frame, how can I achieve that with the current API? What is the old behavior with single_rgb_array?

Contributor Author

@younik younik Aug 3, 2022

If you want to render a single frame representing the current state, you can do so with single_rgb_array, which is similar to the previous behavior:

env = gym.make("MyEnv", render_mode="single_rgb_array")
env.reset()

for _ in range(n_steps):
    env.step(env.action_space.sample())
    env.render()  # returns a single frame

env.close()

Otherwise, with rgb_array:

env = gym.make("MyEnv", render_mode="rgb_array")
env.reset()

for _ in range(n_steps):
    _, _, done, _ = env.step(env.action_space.sample())
    if done:
        break

env.render()  # returns the list of all frames collected so far
env.close()
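
To make the caching behavior described in this thread concrete, here is a toy model of it. This is a sketch, not gym's implementation: `ToyFrameBuffer` is a hypothetical class, frames are plain strings, and it assumes that reset() stores an initial frame (as the "first frame comes from reset" comment in save_video suggests) and that render() returns the cached frames and clears them.

```python
from typing import List


class ToyFrameBuffer:
    """Toy model of the described `rgb_array` behavior; not gym's implementation."""

    def __init__(self) -> None:
        self._frames: List[str] = []
        self._t = 0

    def reset(self) -> None:
        # Assumption: reset() clears the cached frames and records the initial frame.
        self._t = 0
        self._frames = [f"frame-{self._t}"]

    def step(self) -> None:
        # Every step renders and caches one more frame.
        self._t += 1
        self._frames.append(f"frame-{self._t}")

    def render(self) -> List[str]:
        # render() hands back everything cached so far and clears the cache.
        frames, self._frames = self._frames, []
        return frames


buffer = ToyFrameBuffer()
buffer.reset()
for _ in range(3):
    buffer.step()
episode_frames = buffer.render()
print(len(episode_frames))  # reset frame plus three step frames
print(buffer.render())      # nothing is cached right after a render()
```

The model also shows the cost raised by the reviewer: every step pays for a rendered frame, whether or not the episode is eventually saved.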

        ...             "videos",
        ...             fps=env.metadata["render_fps"],
        ...             step_starting_index=step_starting_index,
        ...             episode_index=episode_index
        ...         )
        ...         step_starting_index = step_index + 1
        ...         episode_index += 1
        ...         env.reset()
        >>> env.close()
    """
    if not isinstance(frames, list):
        logger.error(
            f"Expected a list of frames, got a {frames.__class__.__name__} instead."
        )
    if episode_trigger is None and step_trigger is None:
        episode_trigger = capped_cubic_video_schedule

    video_folder = os.path.abspath(video_folder)
    os.makedirs(video_folder, exist_ok=True)
    path_prefix = f"{video_folder}/{name_prefix}"

    if episode_trigger is not None and episode_trigger(episode_index):
        clip = ImageSequenceClip(frames[:video_length], **kwargs)
        clip.write_videofile(f"{path_prefix}-episode-{episode_index}.mp4")

    if step_trigger is not None:
        # skip the first frame since it comes from reset
        for step_index, frame_index in enumerate(
            range(1, len(frames)), start=step_starting_index
        ):
            if step_trigger(step_index):
                end_index = (
                    frame_index + video_length if video_length is not None else None
                )
                clip = ImageSequenceClip(frames[frame_index:end_index], **kwargs)
                clip.write_videofile(f"{path_prefix}-step-{step_index}.mp4")
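
The step_trigger branch above maps global step indices onto slices of the episode's frame list. The sketch below isolates that index arithmetic so it can be checked without moviepy. `step_trigger_slices` is a hypothetical helper that mirrors the loop in save_video and returns the slice bounds instead of writing clips.

```python
from typing import Callable, List, Optional, Tuple


def step_trigger_slices(
    n_frames: int,
    step_trigger: Callable[[int], bool],
    step_starting_index: int = 0,
    video_length: Optional[int] = None,
) -> List[Tuple[int, int, Optional[int]]]:
    """Return (step_index, start, end) for each clip the step_trigger branch would cut."""
    slices = []
    # Frame 0 comes from reset, so global step numbering starts at frame 1.
    for step_index, frame_index in enumerate(
        range(1, n_frames), start=step_starting_index
    ):
        if step_trigger(step_index):
            end = frame_index + video_length if video_length is not None else None
            slices.append((step_index, frame_index, end))
    return slices


# An episode of 10 frames whose first step has global index 95.
clips = step_trigger_slices(
    n_frames=10,
    step_trigger=lambda step: step % 100 == 0,
    step_starting_index=95,
    video_length=3,
)
print(clips)
```

Here global step 100 falls on frame 6 of the episode, so the clip would cover frames[6:9]. With video_length left as None, the end bound is None and the clip runs to the end of the episode.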
11 changes: 10 additions & 1 deletion gym/wrappers/record_video.py
@@ -25,7 +25,7 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
        return episode_id % 1000 == 0


-class RecordVideo(gym.Wrapper):
+class RecordVideo(gym.Wrapper):  # TODO: remove with gym 1.0
    """This wrapper records videos of rollouts.

    Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -36,6 +36,10 @@ class RecordVideo(gym.Wrapper):
    By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can
    also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for
    ``video_length``.

    Note:
        RecordVideo is deprecated.
        Collect the frames with render_mode='rgb_array' and use gym/utils/save_video.py
    """

    def __init__(
@@ -62,6 +66,11 @@ def __init__(
        """
        super().__init__(env, new_step_api)

        logger.deprecation(
            "RecordVideo is deprecated.\n"
            "Collect the frames with render_mode='rgb_array' and use gym/utils/save_video.py"
        )

        if episode_trigger is None and step_trigger is None:
            episode_trigger = capped_cubic_video_schedule

2 changes: 1 addition & 1 deletion setup.py
@@ -20,7 +20,7 @@
    "mujoco_py": ["mujoco_py<2.2,>=2.1"],
    "mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
    "toy_text": ["pygame==2.1.0"],
-    "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"],
+    "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
}

# Testing dependency groups.
104 changes: 104 additions & 0 deletions tests/utils/test_save_video.py
@@ -0,0 +1,104 @@
import os
import shutil

import gym
from gym.utils.save_video import capped_cubic_video_schedule, save_video


def test_record_video_using_default_trigger():
    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)

    env.reset()
    step_starting_index = 0
    episode_index = 0
    for step_index in range(199):
        action = env.action_space.sample()
        _, _, done, _ = env.step(action)
        if done:
            save_video(
                env.render(),
                "videos",
                fps=env.metadata["render_fps"],
                step_starting_index=step_starting_index,
                episode_index=episode_index,
            )
            step_starting_index = step_index + 1
            episode_index += 1
            env.reset()

    env.close()
    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    assert len(mp4_files) == sum(
        capped_cubic_video_schedule(i) for i in range(episode_index)
    )
    shutil.rmtree("videos")
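
The expected file count in the test above follows directly from the default schedule. As a quick sanity check, the sketch below copies `capped_cubic_video_schedule` from the diff and lists the episode indices it selects. `expected_video_count` is a hypothetical helper added only for illustration.

```python
def capped_cubic_video_schedule(episode_id: int) -> bool:
    """Copy of the default trigger from gym/utils/save_video.py in this diff."""
    if episode_id < 1000:
        return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
    else:
        return episode_id % 1000 == 0


def expected_video_count(n_episodes: int) -> int:
    """Hypothetical helper: number of recordings among episodes 0..n_episodes-1."""
    return sum(capped_cubic_video_schedule(i) for i in range(n_episodes))


# Episode indices up to 1000 that trigger a recording.
triggered = [i for i in range(1001) if capped_cubic_video_schedule(i)]
print(triggered)
print(expected_video_count(9))
```

The selected indices are the perfect cubes below 1000 followed by multiples of 1000, so a run that completes nine episodes would be expected to produce three videos (episodes 0, 1 and 8).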


def modulo_step_trigger(mod: int):
    def step_trigger(step_index):
        return step_index % mod == 0

    return step_trigger


def test_record_video_step_trigger():
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env._max_episode_steps = 20

    env.reset()
    step_starting_index = 0
    episode_index = 0
    for step_index in range(199):
        action = env.action_space.sample()
        _, _, done, _ = env.step(action)
        if done:
            save_video(
                env.render(),
                "videos",
                fps=env.metadata["render_fps"],
                step_trigger=modulo_step_trigger(100),
                step_starting_index=step_starting_index,
                episode_index=episode_index,
            )
            step_starting_index = step_index + 1
            episode_index += 1
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    assert len(mp4_files) == 2
    shutil.rmtree("videos")


def test_record_video_within_vector():
    envs = gym.vector.make(
        "CartPole-v1", num_envs=2, asynchronous=True, render_mode="rgb_array"
    )
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs.reset()
    episode_frames = []
    step_starting_index = 0
    episode_index = 0
    for step_index in range(199):
        _, _, _, infos = envs.step(envs.action_space.sample())
        episode_frames.extend(envs.call("render")[0])

        if "episode" in infos and infos["_episode"][0]:
            save_video(
                episode_frames,
                "videos",
                fps=envs.metadata["render_fps"],
                step_trigger=modulo_step_trigger(100),
                step_starting_index=step_starting_index,
                episode_index=episode_index,
            )
            episode_frames = []
            step_starting_index = step_index + 1
            episode_index += 1

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    assert len(mp4_files) == 2
    shutil.rmtree("videos")
102 changes: 0 additions & 102 deletions tests/wrappers/test_record_video.py

This file was deleted.