Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make CompilerEnv.observation_space a gym.Space. #228

Merged
merged 2 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler_gym/bin/manual_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def do_set_benchmark(self, arg):

if self.env.observation_space and observation is not None:
print(
f"Observation: {self.env.observation_space.to_string(observation)}"
f"Observation: {self.env.observation_space_spec.to_string(observation)}"
)

self.set_prompt()
Expand Down Expand Up @@ -486,7 +486,7 @@ def do_action(self, arg):
# Print the observation, if available.
if self.env.observation_space and observation is not None:
print(
f"Observation: {self.env.observation_space.to_string(observation)}"
f"Observation: {self.env.observation_space_spec.to_string(observation)}"
)

# Print the reward, if available.
Expand Down Expand Up @@ -684,7 +684,7 @@ def do_observation(self, arg):
return

if arg == "" and self.env.observation_space:
arg = self.env.observation_space.id
arg = self.env.observation_space_spec.id

if self.observations.count(arg):
with Timer() as timer:
Expand Down
68 changes: 33 additions & 35 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,18 @@ def __init__(
# Lazily evaluated version strings.
self._versions: Optional[GetVersionReply] = None

# Mutable state initialized in reset().
self.action_space: Optional[Space] = None
self.observation_space: Optional[Space] = None

# Mutable state initialized in reset().
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
self.episode_reward: Optional[float] = None
self.episode_start_time: float = time()
self.actions: List[int] = []

# Initialize the default observation/reward spaces.
self._default_observation_space: Optional[ObservationSpaceSpec] = None
self._default_reward_space: Optional[Reward] = None
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
self.reward_space_spec: Optional[Reward] = None
self.observation_space = observation_space
self.reward_space = reward_space

Expand Down Expand Up @@ -461,26 +462,27 @@ def reward_space(self) -> Optional[Reward]:
or :code:`None` if not set.
:setter: Set the default reward space.
"""
return self._default_reward_space
return self.reward_space_spec

@reward_space.setter
def reward_space(self, reward_space: Optional[Union[str, Reward]]) -> None:
if isinstance(reward_space, str) and reward_space not in self.reward.spaces:
raise LookupError(f"Reward space not found: {reward_space}")

reward_space_name = (
# Coerce the observation space into a string.
reward_space: Optional[str] = (
reward_space.id if isinstance(reward_space, Reward) else reward_space
)

self._default_reward: bool = reward_space is not None
self._default_reward_space: Optional[Reward] = None
if self._default_reward:
self._default_reward_space = self.reward.spaces[reward_space_name]
if reward_space:
if reward_space not in self.reward.spaces:
raise LookupError(f"Reward space not found: {reward_space}")
self.reward_space_spec = self.reward.spaces[reward_space]
self.reward_range = (
self._default_reward_space.min,
self._default_reward_space.max,
self.reward_space_spec.min,
self.reward_space_spec.max,
)
else:
# If no reward space is being used then set the reward range to
# unbounded.
self.reward_space_spec = None
self.reward_range = (-np.inf, np.inf)

@property
Expand All @@ -501,30 +503,26 @@ def observation_space(self) -> Optional[ObservationSpaceSpec]:
:code:`None` if not set.
:setter: Set the default observation space.
"""
return self._default_observation_space
if self.observation_space_spec:
return self.observation_space_spec.space

@observation_space.setter
def observation_space(
self, observation_space: Optional[Union[str, ObservationSpaceSpec]]
) -> None:
if (
isinstance(observation_space, str)
and observation_space not in self.observation.spaces
):
raise LookupError(f"Observation space not found: {observation_space}")

observation_space_name = (
# Coerce the observation space into a string.
observation_space: Optional[str] = (
observation_space.id
if isinstance(observation_space, ObservationSpaceSpec)
else observation_space
)

self._default_observation = observation_space is not None
self._default_observation_space: Optional[ObservationSpaceSpec] = None
if self._default_observation:
self._default_observation_space = self.observation.spaces[
observation_space_name
]
if observation_space:
if observation_space not in self.observation.spaces:
raise LookupError(f"Observation space not found: {observation_space}")
self.observation_space_spec = self.observation.spaces[observation_space]
else:
self.observation_space_spec = None

def fork(self) -> "CompilerEnv":
"""Fork a new environment with exactly the same state.
Expand Down Expand Up @@ -603,7 +601,7 @@ def fork(self) -> "CompilerEnv":
# Set the default observation and reward types. Note the use of IDs here
# to prevent passing the spaces by reference.
if self.observation_space:
new_env.observation_space = self.observation_space.id
new_env.observation_space = self.observation_space_spec.id
if self.reward_space:
new_env.reward_space = self.reward_space.id

Expand Down Expand Up @@ -705,7 +703,7 @@ def reset( # pylint: disable=arguments-differ
else 0
),
observation_space=(
[self.observation_space.index]
[self.observation_space_spec.index]
if self.observation_space
else None
),
Expand Down Expand Up @@ -752,7 +750,7 @@ def reset( # pylint: disable=arguments-differ
raise OSError(
f"Expected one observation from service, received {len(reply.observation)}"
)
return self.observation.spaces[self.observation_space.id].translate(
return self.observation.spaces[self.observation_space_spec.id].translate(
reply.observation[0]
)

Expand All @@ -778,8 +776,8 @@ def step(self, action: Union[int, Iterable[int]]) -> step_t:
# requested.
observation_indices, observation_spaces = [], []
if self.observation_space:
observation_indices.append(self.observation_space.index)
observation_spaces.append(self.observation_space.id)
observation_indices.append(self.observation_space_spec.index)
observation_spaces.append(self.observation_space_spec.id)
if self.reward_space:
observation_indices += [
self.observation.spaces[obs].index
Expand Down Expand Up @@ -816,7 +814,7 @@ def step(self, action: Union[int, Iterable[int]]) -> step_t:
if self.reward_space:
reward = self.reward_space.reward_on_error(self.episode_reward)
if self.observation_space:
observation = self.observation_space.default_value
observation = self.observation_space_spec.default_value
return observation, reward, True, info

# If the action space has changed, update it.
Expand Down Expand Up @@ -870,7 +868,7 @@ def render(
"""
if not self.observation_space:
raise ValueError("Cannot call render() when no observation space is used")
observation = self.observation[self.observation_space.id]
observation = self.observation[self.observation_space_spec.id]
if mode == "human":
print(observation)
elif mode == "ansi":
Expand Down
4 changes: 2 additions & 2 deletions docs/source/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ The observation space is described by
The :ref:`Autophase <llvm/index:Autophase>` observation space is a 56-dimension
vector of integers:

>>> env.observation_space.space.shape
>>> env.observation_space.shape
(56,)
>>> env.observation_space.space.dtype
>>> env.observation_space.dtype
dtype('int64')

The upper and lower bounds of the reward signal are described by
Expand Down
4 changes: 2 additions & 2 deletions examples/getting-started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@
"metadata": {},
"outputs": [],
"source": [
"env.observation_space.space.shape"
"env.observation_space.shape"
]
},
{
Expand All @@ -247,7 +247,7 @@
"metadata": {},
"outputs": [],
"source": [
"env.observation_space.space.dtype"
"env.observation_space.dtype"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run_random_walk(env: CompilerEnv, step_count: int) -> None:
rewards.append(reward)
actions.append(env.action_space.names[action_index])
print(f"Reward: {reward}")
if env._default_observation:
if env.observation_space:
print(f"Observation:\n{observation}")
print(f"Step time: {step_time}")
if done:
Expand Down
10 changes: 10 additions & 0 deletions tests/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ py_test(
],
)

py_test(
name = "gym_interface_compatability",
timeout = "short",
srcs = ["gym_interface_compatability.py"],
deps = [
"//compiler_gym",
"//tests:test_main",
],
)

py_test(
name = "llvm_benchmarks_test",
srcs = ["llvm_benchmarks_test.py"],
Expand Down
81 changes: 81 additions & 0 deletions tests/llvm/gym_interface_compatability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Test that LlvmEnv is compatible with OpenAI gym interface."""
import gym
import pytest

import compiler_gym # noqa Register Environments
from compiler_gym.envs import CompilerEnv
from tests.test_main import main


@pytest.fixture(scope="function")
def env() -> CompilerEnv:
env = gym.make("llvm-autophase-ic-v0")
try:
yield env
finally:
env.close()


def test_type_classes(env: CompilerEnv):
assert isinstance(env, gym.Env)
assert isinstance(env, CompilerEnv)
assert isinstance(env.unwrapped, CompilerEnv)
assert isinstance(env.action_space, gym.Space)
assert isinstance(env.observation_space, gym.Space)
assert isinstance(env.reward_range[0], float)
assert isinstance(env.reward_range[1], float)


def test_optional_properties(env: CompilerEnv):
assert "render.modes" in env.metadata
assert env.spec


def test_contextmanager(env: CompilerEnv, mocker):
mocker.spy(env, "close")
assert env.close.call_count == 0
with env:
pass
assert env.close.call_count == 1


def test_contextmanager_gym_make(mocker):
with gym.make("llvm-v0") as env:
mocker.spy(env, "close")
assert env.close.call_count == 0
with env:
pass
assert env.close.call_count == 1


def test_observation_wrapper(env: CompilerEnv):
class WrappedEnv(gym.ObservationWrapper):
def observation(self, observation):
return "Hello"

wrapped = WrappedEnv(env)
observation = wrapped.reset()
assert observation == "Hello"

observation, _, _, _ = wrapped.step(0)
assert observation == "Hello"


def test_reward_wrapper(env: CompilerEnv):
class WrappedEnv(gym.RewardWrapper):
def reward(self, reward):
return 1

wrapped = WrappedEnv(env)
wrapped.reset()

_, reward, _, _ = wrapped.step(0)
assert reward == 1


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion tests/llvm/llvm_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_gym_make_kwargs():
"llvm-v0", observation_space="Autophase", reward_space="IrInstructionCount"
)
try:
assert env.observation_space.id == "Autophase"
assert env.observation_space_spec.id == "Autophase"
assert env.reward_space.id == "IrInstructionCount"
finally:
env.close()
Expand Down
7 changes: 4 additions & 3 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@

def test_default_observation_space(env: LlvmEnv):
env.observation_space = "Autophase"
assert env.observation_space.id == "Autophase"
assert env.observation_space.shape == (56,)
assert env.observation_space_spec.id == "Autophase"

env.observation_space = None
assert env.observation_space is None
assert env.observation_space_spec is None

invalid = "invalid value"
with pytest.raises(LookupError) as ctx:
with pytest.raises(LookupError, match=f"Observation space not found: {invalid}"):
env.observation_space = invalid
assert str(ctx.value) == f"Observation space not found: {invalid}"


def test_observation_spaces(env: LlvmEnv):
Expand Down