Skip to content

Commit

Permalink
Gym v0.26 compliance
Browse files Browse the repository at this point in the history
  • Loading branch information
JesseFarebro committed Sep 5, 2022
1 parent 093979a commit bfee9e3
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 175 deletions.
32 changes: 20 additions & 12 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
name: CI
on:
workflow_dispatch:
inputs:
debug_enabled:
type: boolean
description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)'
required: false
default: false
push:
tags-ignore:
- "*.*"
Expand Down Expand Up @@ -78,30 +85,31 @@ jobs:
with:
python-version: "3.x"

- name: Install pip dependencies
- name: Install test dependencies
# TODO(jfarebro): There's a bug with Windows cmake and PEP517 builds via pip install.
# As a temporary workaround installing cmake outside of the isolated env seems to work.
run: |
python -m pip install --user --upgrade -r tests/requirements.txt
python -m pip install --user cmake
- uses: microsoft/[email protected]
if: runner.os == 'Windows'
- uses: lukka/get-cmake@latest
- uses: lukka/run-vcpkg@v10
with:
vcpkgGitCommitId: "aebb363eaa0b658beb19cbefdd5aa2f9cbc14f1e"
# There's a permissions issue with the cache
# https://github.com/microsoft/vcpkg/issues/20121
doNotCache: true

- name: Configure
run: |
cmake --version
mkdir build && cd build
cmake ../ -DSDL_SUPPORT=ON -DPython3_EXECUTABLE=$(which python)
- name: Build
working-directory: build
run: cmake --build . --config Debug --parallel 2
run: python -m pip install --user --verbose .

- name: Test
working-directory: build
run: ctest -C Debug --progress -VV
run: python -m pytest

# Enable tmate debugging of manually-triggered workflows if the input option was provided
- name: Setup SSH debug session on failure
uses: mxschmitt/action-tmate@v3
if: ${{ github.event_name == 'workflow_dispatch' && inputs.debug_enabled && failure() }}
with:
limit-access-to-actor: true
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"numpy",
"importlib-metadata>=4.10.0; python_version < '3.10'",
"importlib-resources",
"typing-extensions; python_version < '3.11'"
]
dynamic = ["version"]

Expand All @@ -61,8 +62,8 @@ __internal__ = "ale_py.gym:register_legacy_gym_envs"
packages = [
"ale_py",
"ale_py.roms",
"ale_py.scripts",
"gym.envs.atari"
"ale_py.env",
"ale_py.scripts"
]
package-dir = {ale_py = "src/python", gym = "src/gym"}
package-data = {"ale_py" = ["py.typed"], "ale_py.roms" = ["*.bin", "md5.txt"]}
Expand All @@ -80,7 +81,9 @@ skip = ["*-win32", "*i686", "pp*", "*-musllinux*"]
build-frontend = "build"

# Test configuration
test-extras = ["test"]
# test-extras = ["test"]
# TODO(jfarebro): Temporarily use upstream Gym until v26 release.
test-requires = ["pytest", "git+https://github.com/openai/gym#egg=gym"]
test-command = "pytest {project}"

# vcpkg manylinux images
Expand Down
25 changes: 13 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import re
import sys
import subprocess
import sys

from setuptools import setup, Extension
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext

here = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -53,11 +53,15 @@ def build_extension(self, ext):
# exported for Ninja to pick it up, which is a little tricky to do.
# Users can override the generator with CMAKE_GENERATOR in CMake
# 3.15+.
if not cmake_generator:
if not cmake_generator or cmake_generator == "Ninja":
try:
import ninja # noqa: F401

cmake_args += ["-GNinja"]
ninja_executable_path = os.path.join(ninja.BIN_DIR, "ninja")
cmake_args += [
"-GNinja",
f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
]
except ImportError:
pass

Expand Down Expand Up @@ -95,15 +99,12 @@ def build_extension(self, ext):
if hasattr(self, "parallel") and self.parallel:
build_args += [f"-j{self.parallel}"]

if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
build_temp = os.path.join(self.build_temp, ext.name)
if not os.path.exists(build_temp):
os.makedirs(build_temp)

subprocess.check_call(
["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
)
subprocess.check_call(
["cmake", "--build", "."] + build_args, cwd=self.build_temp
)
subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp)


def parse_version(version_file):
Expand Down
1 change: 0 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ endif()
# Python Library
if (BUILD_PYTHON_LIB)
add_subdirectory(python)
add_subdirectory(gym)
endif()

# Install cpp-lib and CMake config & pkg-config
Expand Down
7 changes: 0 additions & 7 deletions src/gym/CMakeLists.txt

This file was deleted.

1 change: 0 additions & 1 deletion src/gym/envs/atari/__init__.py

This file was deleted.

Empty file added src/python/env/__init__.py
Empty file.
89 changes: 46 additions & 43 deletions src/gym/envs/atari/environment.py → src/python/env/gym.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import sys
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import ale_py
import ale_py.roms as roms
import ale_py.roms.utils as rom_utils
import numpy as np

import gym
import gym.logger as logger
import numpy as np
from gym import error, spaces, utils

if sys.version_info < (3, 11):
from typing_extensions import NotRequired, TypedDict
else:
from typing import NotRequired, TypedDict


class AtariEnvStepMetadata(TypedDict):
lives: int
episode_frame_number: int
frame_number: int
seeds: NotRequired[Sequence[int]]


class AtariEnv(gym.Env, utils.EzPickle):
"""
Expand All @@ -27,6 +41,7 @@ def __init__(
frameskip: Union[Tuple[int, int], int] = 4,
repeat_action_probability: float = 0.25,
full_action_space: bool = False,
max_num_frames_per_episode: Optional[int] = None,
render_mode: Optional[str] = None,
) -> None:
"""
Expand All @@ -43,6 +58,9 @@ def __init__(
repeat_action_probability: int =>
Probability to repeat actions, see Machado et al., 2018
full_action_space: bool => Use full action space?
max_num_frames_per_episode: int => Max number of frame per epsiode.
Once `max_num_frames_per_episode` is reached the episode is
truncated.
render_mode: str => One of { 'human', 'rgb_array' }.
If `human` we'll interactively display the screen and enable
game sounds. This will lock emulation to the ROMs specified FPS
Expand Down Expand Up @@ -104,6 +122,7 @@ def __init__(
frameskip,
repeat_action_probability,
full_action_space,
max_num_frames_per_episode,
render_mode,
)

Expand All @@ -124,6 +143,9 @@ def __init__(
# Config sticky action prob.
self.ale.setFloat("repeat_action_probability", repeat_action_probability)

if max_num_frames_per_episode is not None:
self.ale.setInt("max_num_frames_per_episode", max_num_frames_per_episode)

# If render mode is human we can display screen and sound
if render_mode == "human":
self.ale.setBool("display_screen", True)
Expand Down Expand Up @@ -197,12 +219,12 @@ def seed(self, seed: Optional[int] = None) -> Tuple[int, int]:
if self._game_difficulty is not None:
self.ale.setDifficulty(self._game_difficulty)

return (
seed1,
seed2,
)
return seed1, seed2

def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
def step(
self,
action_ind: int,
) -> Tuple[np.ndarray, float, bool, bool, AtariEnvStepMetadata]:
"""
Perform one agent step, i.e., repeats `action` frameskip # of steps.
Expand Down Expand Up @@ -232,20 +254,21 @@ def step(self, action_ind: int) -> Tuple[np.ndarray, float, bool, Dict[str, Any]
reward = 0.0
for _ in range(frameskip):
reward += self.ale.act(action)
terminal = self.ale.game_over()
is_terminal = self.ale.game_over(with_truncation=False)
is_truncated = self.ale.game_truncated()

return self._get_obs(), reward, terminal, self._get_info()
return self._get_obs(), reward, is_terminal, is_truncated, self._get_info()

def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[Dict[str, Any]] = None,
) -> Union[Tuple[np.ndarray, Dict[str, Any]], np.ndarray]:
) -> Tuple[np.ndarray, AtariEnvStepMetadata]:
"""
Resets environment and returns initial observation.
"""
del options
# Gym's new seeding API seeds on reset.
# This will cause the console to be recreated
# and loose all previous state, e.g., statistics, etc.
Expand All @@ -256,15 +279,12 @@ def reset(
self.ale.reset_game()
obs = self._get_obs()

if return_info:
info = self._get_info()
if seeded_with is not None:
info["seeds"] = seeded_with
return obs, info
else:
return obs
info = self._get_info()
if seeded_with is not None:
info["seeds"] = seeded_with
return obs, info

def render(self, mode: str) -> np.ndarray:
def render(self) -> Any:
"""
Render is not supported by ALE. We use a paradigm similar to
Gym3 which allows you to specify `render_mode` during construction.
Expand All @@ -274,28 +294,16 @@ def render(self, mode: str) -> np.ndarray:
will display the ALE and maintain the proper interval to match the
FPS target set by the ROM.
"""
if mode == "rgb_array":
if self._render_mode == "rgb_array":
return self.ale.getScreenRGB()
elif mode == "human":
raise error.Error(
(
"render(mode='human') is deprecated. Please supply `render_mode` when "
"constructing your environment, e.g., gym.make(ID, render_mode='human'). "
"The new `render_mode` keyword argument supports DPI scaling, "
"audio, and native framerates."
)
)
elif self._render_mode == "human":
pass
else:
raise error.Error(
f"Invalid render mode `{mode}`. Supported modes: `rgb_array`."
f"Invalid render mode `{self._render_mode}`. "
"Supported modes: `human`, `rgb_array`."
)

def close(self) -> None:
"""
Cleanup any leftovers by the environment
"""
pass

def _get_obs(self) -> np.ndarray:
"""
Retreives the current observation.
Expand All @@ -310,18 +318,13 @@ def _get_obs(self) -> np.ndarray:
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")

def _get_info(self) -> Dict[str, Any]:
info = {
def _get_info(self) -> AtariEnvStepMetadata:
return {
"lives": self.ale.lives(),
"episode_frame_number": self.ale.getEpisodeFrameNumber(),
"frame_number": self.ale.getFrameNumber(),
}

if self._render_mode == "rgb_array":
info["rgb"] = self.ale.getScreenRGB()

return info

def get_keys_to_action(self) -> Dict[Tuple[int], ale_py.Action]:
"""
Return keymapping -> actions for human play.
Expand Down
Loading

0 comments on commit bfee9e3

Please sign in to comment.