-
Notifications
You must be signed in to change notification settings - Fork 8.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add save_video util and deprecate RecordVideo in favor of it #3016
Changes from 4 commits
e9bd988
5cc14da
e5c34c5
5d50131
ee90762
ae34778
ffcfe75
5c0074b
26c0108
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
"""Utility functions to save rendering videos.""" | ||
import os | ||
from typing import Callable, Optional | ||
|
||
import gym | ||
from gym import logger | ||
|
||
try: | ||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip | ||
except ImportError: | ||
raise gym.error.DependencyNotInstalled( | ||
"MoviePy is not installed, run `pip install moviepy`" | ||
) | ||
|
||
|
||
def capped_cubic_video_schedule(episode_id: int) -> bool: | ||
"""The default episode trigger. | ||
|
||
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ... | ||
|
||
Args: | ||
episode_id: The episode number | ||
|
||
Returns: | ||
If to apply a video schedule number | ||
""" | ||
if episode_id < 1000: | ||
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id | ||
else: | ||
return episode_id % 1000 == 0 | ||
|
||
|
||
def save_video( | ||
frames: list, | ||
video_folder: str, | ||
episode_trigger: Callable[[int], bool] = None, | ||
step_trigger: Callable[[int], bool] = None, | ||
video_length: Optional[int] = None, | ||
name_prefix: str = "rl-video", | ||
episode_index: int = 0, | ||
step_starting_index: int = 0, | ||
**kwargs, | ||
): | ||
"""Save videos from rendering frames. | ||
|
||
This function extract video from a list of render frame episodes. | ||
|
||
Args: | ||
frames (List[RenderFrame]): A list of frames to compose the video. | ||
video_folder (str): The folder where the recordings will be stored | ||
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode | ||
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step | ||
video_length (int): The length of recorded episodes. If it isn't specified, the entire episode is recorded. | ||
Otherwise, snippets of the specified length are captured. | ||
name_prefix (str): Will be prepended to the filename of the recordings. | ||
episode_index (int): The index of the current episode. | ||
step_starting_index (int): The step index of the first frame. | ||
**kwargs: The kwargs that will be passed to moviepy's ImageSequenceClip. | ||
You need to specify either fps or duration. | ||
|
||
Example: | ||
>>> import gym | ||
>>> from gym.utils.save_video import save_video | ||
>>> env = gym.make("FrozenLake-v1", render_mode="rgb_array") | ||
>>> env.reset() | ||
>>> step_starting_index = 0 | ||
>>> episode_index = 0 | ||
>>> for step_index in range(199): | ||
... action = env.action_space.sample() | ||
... _, _, done, _ = env.step(action) | ||
... if done: | ||
... save_video( | ||
... env.render(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like here only calls There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, with the current Render API When you call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I am struggling to understand this. If I don’t want to render every frame, how can I achieve it with the current API? What is the old-like behavior with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to render a single frame representing the current state, you can achieve it with env = gym.make("MyEnv", render_mode="single_rgb_array")
env.reset()
for _ in range(n_steps):
env.step(...)
env.render() # single frame
env.close() Otherwise, with env = gym.make("MyEnv", render_mode="rgb_array")
env.reset()
for _ in range(n_steps):
_, _, done, _ = env.step(...)
if done:
break
env.render() # List of all frames
env.close() |
||
... "videos", | ||
... fps=env.metadata["render_fps"], | ||
... step_starting_index=step_starting_index, | ||
... episode_index=episode_index | ||
... ) | ||
... step_starting_index = step_index + 1 | ||
... episode_index += 1 | ||
... env.reset() | ||
>>> env.close() | ||
""" | ||
if not isinstance(frames, list): | ||
logger.error( | ||
f"Expected a list of frames, got a {frames.__class__.__name__} instead." | ||
) | ||
if episode_trigger is None and step_trigger is None: | ||
episode_trigger = capped_cubic_video_schedule | ||
|
||
video_folder = os.path.abspath(video_folder) | ||
os.makedirs(video_folder, exist_ok=True) | ||
path_prefix = f"{video_folder}/{name_prefix}" | ||
|
||
if episode_trigger is not None and episode_trigger(episode_index): | ||
clip = ImageSequenceClip(frames[:video_length], **kwargs) | ||
clip.write_videofile(f"{path_prefix}-episode-{episode_index}.mp4") | ||
|
||
if step_trigger is not None: | ||
# skip the first frame since it comes from reset | ||
for step_index, frame_index in enumerate( | ||
range(1, len(frames)), start=step_starting_index | ||
): | ||
if step_trigger(step_index): | ||
end_index = ( | ||
frame_index + video_length if video_length is not None else None | ||
) | ||
clip = ImageSequenceClip(frames[frame_index:end_index], **kwargs) | ||
clip.write_videofile(f"{path_prefix}-step-{step_index}.mp4") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import os | ||
import shutil | ||
|
||
import gym | ||
from gym.utils.save_video import capped_cubic_video_schedule, save_video | ||
|
||
|
||
def test_record_video_using_default_trigger(): | ||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) | ||
|
||
env.reset() | ||
step_starting_index = 0 | ||
episode_index = 0 | ||
for step_index in range(199): | ||
action = env.action_space.sample() | ||
_, _, done, _ = env.step(action) | ||
if done: | ||
save_video( | ||
env.render(), | ||
"videos", | ||
fps=env.metadata["render_fps"], | ||
step_starting_index=step_starting_index, | ||
episode_index=episode_index, | ||
) | ||
step_starting_index = step_index + 1 | ||
episode_index += 1 | ||
env.reset() | ||
|
||
env.close() | ||
assert os.path.isdir("videos") | ||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] | ||
assert len(mp4_files) == sum( | ||
capped_cubic_video_schedule(i) for i in range(episode_index) | ||
) | ||
shutil.rmtree("videos") | ||
|
||
|
||
def modulo_step_trigger(mod: int): | ||
def step_trigger(step_index): | ||
return step_index % mod == 0 | ||
|
||
return step_trigger | ||
|
||
|
||
def test_record_video_step_trigger(): | ||
env = gym.make("CartPole-v1", render_mode="rgb_array") | ||
env._max_episode_steps = 20 | ||
|
||
env.reset() | ||
step_starting_index = 0 | ||
episode_index = 0 | ||
for step_index in range(199): | ||
action = env.action_space.sample() | ||
_, _, done, _ = env.step(action) | ||
if done: | ||
save_video( | ||
env.render(), | ||
"videos", | ||
fps=env.metadata["render_fps"], | ||
step_trigger=modulo_step_trigger(100), | ||
step_starting_index=step_starting_index, | ||
episode_index=episode_index, | ||
) | ||
step_starting_index = step_index + 1 | ||
episode_index += 1 | ||
env.reset() | ||
env.close() | ||
|
||
assert os.path.isdir("videos") | ||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] | ||
assert len(mp4_files) == 2 | ||
shutil.rmtree("videos") | ||
|
||
|
||
def test_record_video_within_vector(): | ||
envs = gym.vector.make( | ||
"CartPole-v1", num_envs=2, asynchronous=True, render_mode="rgb_array" | ||
) | ||
envs = gym.wrappers.RecordEpisodeStatistics(envs) | ||
envs.reset() | ||
episode_frames = [] | ||
step_starting_index = 0 | ||
episode_index = 0 | ||
for step_index in range(199): | ||
_, _, _, infos = envs.step(envs.action_space.sample()) | ||
episode_frames.extend(envs.call("render")[0]) | ||
|
||
if "episode" in infos and infos["_episode"][0]: | ||
save_video( | ||
episode_frames, | ||
"videos", | ||
fps=envs.metadata["render_fps"], | ||
step_trigger=modulo_step_trigger(100), | ||
step_starting_index=step_starting_index, | ||
episode_index=episode_index, | ||
) | ||
episode_frames = [] | ||
step_starting_index = step_index + 1 | ||
episode_index += 1 | ||
|
||
assert os.path.isdir("videos") | ||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] | ||
assert len(mp4_files) == 2 | ||
shutil.rmtree("videos") |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am very confused - why would
episode_trigger
still be here? It seems strange when users executesave_video
the videos might not be saved…There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the beginning, the expected type of frames was a List of episodes, where an episode is a List of frames.
Talking with @pseudo-rnd-thoughts, we agree it is simpler as it is now, with frames as a List of frames of the same episode.
For how it is right now, yes, we can drop
episode_trigger
and the user can just do it with an if statement.Notice that it still can happen that
save_video
doesn't save anything depending onstep_trigger