Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split CompilerEnv.step() into two methods for singular or lists of actions (take 2) #627

Merged
merged 8 commits into from
Mar 17, 2022
11 changes: 8 additions & 3 deletions compiler_gym/bin/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
from compiler_gym.datasets import Dataset
from compiler_gym.envs import CompilerEnv
from compiler_gym.service.connection import ConnectionOpts
from compiler_gym.spaces import Commandline
from compiler_gym.spaces import Commandline, NamedDiscrete
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.util.tabulate import tabulate
from compiler_gym.util.truncate import truncate
Expand Down Expand Up @@ -249,12 +249,17 @@ def print_service_capabilities(env: CompilerEnv):
],
headers=("Action", "Description"),
)
else:
print(table)
elif isinstance(action_space, NamedDiscrete):
table = tabulate(
[(a,) for a in sorted(action_space.names)],
headers=("Action",),
)
print(table)
print(table)
else:
raise NotImplementedError(
"Only Commandline and NamedDiscrete are supported."
)


def main(argv):
Expand Down
159 changes: 114 additions & 45 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class CompilerEnv(gym.Env):
:ivar actions: The list of actions that have been performed since the
previous call to :func:`reset`.

:vartype actions: List[int]
:vartype actions: List[ActionType]

:ivar reward_range: A tuple indicating the range of reward values. Default
range is (-inf, +inf).
Expand Down Expand Up @@ -321,7 +321,7 @@ def __init__(
self.reward_range: Tuple[float, float] = (-np.inf, np.inf)
self.episode_reward: Optional[float] = None
self.episode_start_time: float = time()
self.actions: List[int] = []
self.actions: List[ActionType] = []

# Initialize the default observation/reward spaces.
self.observation_space_spec: Optional[ObservationSpaceSpec] = None
Expand Down Expand Up @@ -375,7 +375,7 @@ def commandline(self) -> str:
"""
raise NotImplementedError("abstract method")

def commandline_to_actions(self, commandline: str) -> List[int]:
def commandline_to_actions(self, commandline: str) -> List[ActionType]:
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
subclasses to convert from a commandline invocation to a sequence of
actions.
Expand Down Expand Up @@ -409,7 +409,7 @@ def state(self) -> CompilerEnvState:
)

@property
def action_space(self) -> NamedDiscrete:
def action_space(self) -> Space:
"""The current action space.

:getter: Get the current action space.
Expand Down Expand Up @@ -587,7 +587,7 @@ def fork(self) -> "CompilerEnv":
self.reset()
if actions:
logger.warning("Parent service of fork() has died, replaying state")
_, _, done, _ = self.step(actions)
_, _, done, _ = self.multistep(actions)
assert not done, "Failed to replay action sequence"

request = ForkSessionRequest(session_id=self._session_id)
Expand Down Expand Up @@ -620,7 +620,7 @@ def fork(self) -> "CompilerEnv":
# replay the state.
new_env = type(self)(**self._init_kwargs())
new_env.reset()
_, _, done, _ = new_env.step(self.actions)
_, _, done, _ = new_env.multistep(self.actions)
assert not done, "Failed to replay action sequence in forked environment"

# Create copies of the mutable reward and observation spaces. This
Expand Down Expand Up @@ -885,9 +885,9 @@ def _call_with_error(

def raw_step(
self,
actions: Iterable[int],
observations: Iterable[ObservationSpaceSpec],
rewards: Iterable[Reward],
actions: Iterable[ActionType],
observation_spaces: List[ObservationSpaceSpec],
reward_spaces: List[Reward],
) -> StepType:
"""Take a step.

Expand All @@ -908,26 +908,23 @@ def raw_step(

.. warning::

Prefer :meth:`step() <compiler_gym.envs.CompilerEnv.step>` to
:meth:`raw_step() <compiler_gym.envs.CompilerEnv.step>`.
:meth:`step() <compiler_gym.envs.CompilerEnv.step>` has equivalent
functionality, and is less likely to change in the future.
Don't call this method directly, use :meth:`step()
<compiler_gym.envs.CompilerEnv.step>` or :meth:`multistep()
<compiler_gym.envs.CompilerEnv.multistep>` instead. The
:meth:`raw_step() <compiler_gym.envs.CompilerEnv.raw_step>` method is an
implementation detail.
"""
if not self.in_episode:
raise SessionNotFound("Must call reset() before step()")

# Build the list of observations that must be computed by the backend
user_observation_spaces: List[ObservationSpaceSpec] = list(observations)
reward_spaces: List[Reward] = list(rewards)

reward_observation_spaces: List[ObservationSpaceSpec] = []
for reward_space in reward_spaces:
reward_observation_spaces += [
self.observation.spaces[obs] for obs in reward_space.observation_spaces
]

observations_to_compute: List[ObservationSpaceSpec] = list(
set(user_observation_spaces).union(set(reward_observation_spaces))
set(observation_spaces).union(set(reward_observation_spaces))
)
observation_space_index_map: Dict[ObservationSpaceSpec, int] = {
observation_space: i
Expand Down Expand Up @@ -974,7 +971,7 @@ def raw_step(

default_observations = [
observation_space.default_value
for observation_space in user_observation_spaces
for observation_space in observation_spaces
]
default_rewards = [
float(reward_space.reward_on_error(self.episode_reward))
Expand Down Expand Up @@ -1002,7 +999,7 @@ def raw_step(
# Get the user-requested observation.
observations: List[ObservationType] = [
computed_observations[observation_space_index_map[observation_space]]
for observation_space in user_observation_spaces
for observation_space in observation_spaces
]

# Update and compute the rewards.
Expand All @@ -1029,25 +1026,83 @@ def raw_step(

return observations, rewards, reply.end_of_session, info

def step(
def step(  # pylint: disable=arguments-differ
    self,
    action: ActionType,
    observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
    reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
    observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
    rewards: Optional[Iterable[Union[str, Reward]]] = None,
) -> StepType:
    """Apply a single action and return ``(observation, reward, done, info)``.

    :param action: A single action to apply.

    :param observation_spaces: An optional list of observation spaces to
        compute observations from. When given, the ``observation`` element
        of the returned tuple becomes a list of observations from these
        spaces, and the default :code:`env.observation_space` is not
        returned.

    :param reward_spaces: An optional list of reward spaces to compute
        rewards from. When given, the ``reward`` element of the returned
        tuple becomes a list of rewards from these spaces, and the default
        :code:`env.reward_space` is not returned.

    :return: A tuple of observation, reward, done, and info. Observation and
        reward are None if default observation/reward is not set.

    :raises SessionNotFound: If :meth:`reset()
        <compiler_gym.envs.CompilerEnv.reset>` has not been called.
    """

    def _warn_renamed(old: str, new: str) -> None:
        # One DeprecationWarning per legacy keyword argument supplied.
        warnings.warn(
            f"Argument `{old}` of CompilerEnv.step has been "
            f"renamed `{new}`. Please update your code",
            category=DeprecationWarning,
        )

    # Legacy call style: a sequence of actions was passed where a single
    # action is now expected. Forward everything, including the deprecated
    # keyword names, to multistep(), which owns batched behavior now. This
    # check runs before the keyword-rename handling so that multistep()
    # itself warns about any legacy keyword names.
    if isinstance(action, IterableType):
        warnings.warn(
            # NOTE: the doubled space after "list" is preserved from the
            # original message text.
            "Argument `action` of CompilerEnv.step no longer accepts a list "
            " of actions. Please use CompilerEnv.multistep instead",
            category=DeprecationWarning,
        )
        return self.multistep(
            action,
            observation_spaces=observation_spaces,
            reward_spaces=reward_spaces,
            observations=observations,
            rewards=rewards,
        )

    # Map the deprecated keyword names onto their replacements.
    if observations is not None:
        _warn_renamed("observations", "observation_spaces")
        observation_spaces = observations
    if rewards is not None:
        _warn_renamed("rewards", "reward_spaces")
        reward_spaces = rewards

    # A single action is just a one-element batch.
    return self.multistep([action], observation_spaces, reward_spaces)

def multistep(
self,
actions: Iterable[ActionType],
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
):
"""Take a sequence of steps and return the final observation and reward.

:param actions: A sequence of actions to apply in order.

:param observations: A list of observation spaces to compute
:param observation_spaces: A list of observation spaces to compute
observations from. If provided, this changes the :code:`observation`
element of the return tuple to be a list of observations from the
requested spaces. The default :code:`env.observation_space` is not
returned.

:param rewards: A list of reward spaces to compute rewards from. If
:param reward_spaces: A list of reward spaces to compute rewards from. If
provided, this changes the :code:`reward` element of the return
tuple to be a list of rewards from the requested spaces. The default
:code:`env.reward_space` is not returned.
Expand All @@ -1058,52 +1113,64 @@ def step(
:raises SessionNotFound: If :meth:`reset()
<compiler_gym.envs.CompilerEnv.reset>` has not been called.
"""
# Coerce actions into a list.
actions = action if isinstance(action, IterableType) else [action]
if observations is not None:
warnings.warn(
"Argument `observations` of CompilerEnv.multistep has been "
"renamed `observation_spaces`. Please update your code",
category=DeprecationWarning,
)
observation_spaces = observations
if rewards is not None:
warnings.warn(
"Argument `rewards` of CompilerEnv.multistep has been renamed "
"`reward_spaces`. Please update your code",
category=DeprecationWarning,
)
reward_spaces = rewards

# Coerce observation spaces into a list of ObservationSpaceSpec instances.
if observations:
observation_spaces: List[ObservationSpaceSpec] = [
if observation_spaces:
observation_spaces_to_compute: List[ObservationSpaceSpec] = [
obs
if isinstance(obs, ObservationSpaceSpec)
else self.observation.spaces[obs]
for obs in observations
for obs in observation_spaces
]
elif self.observation_space_spec:
observation_spaces: List[ObservationSpaceSpec] = [
observation_spaces_to_compute: List[ObservationSpaceSpec] = [
self.observation_space_spec
]
else:
observation_spaces: List[ObservationSpaceSpec] = []
observation_spaces_to_compute: List[ObservationSpaceSpec] = []

# Coerce reward spaces into a list of Reward instances.
if rewards:
reward_spaces: List[Reward] = [
if reward_spaces:
reward_spaces_to_compute: List[Reward] = [
rew if isinstance(rew, Reward) else self.reward.spaces[rew]
for rew in rewards
for rew in reward_spaces
]
elif self.reward_space:
reward_spaces: List[Reward] = [self.reward_space]
reward_spaces_to_compute: List[Reward] = [self.reward_space]
else:
reward_spaces: List[Reward] = []
reward_spaces_to_compute: List[Reward] = []

# Perform the underlying environment step.
observation_values, reward_values, done, info = self.raw_step(
actions, observation_spaces, reward_spaces
actions, observation_spaces_to_compute, reward_spaces_to_compute
)

# Translate observations lists back to the appropriate types.
if observations is None and self.observation_space_spec:
if observation_spaces is None and self.observation_space_spec:
observation_values = observation_values[0]
elif not observation_spaces:
elif not observation_spaces_to_compute:
observation_values = None

# Translate reward lists back to the appropriate types.
if rewards is None and self.reward_space:
if reward_spaces is None and self.reward_space:
reward_values = reward_values[0]
# Update the cumulative episode reward
self.episode_reward += reward_values
elif not reward_spaces:
elif not reward_spaces_to_compute:
reward_values = None

return observation_values, reward_values, done, info
Expand Down Expand Up @@ -1176,7 +1243,9 @@ def apply(self, state: CompilerEnvState) -> None: # noqa
)

actions = self.commandline_to_actions(state.commandline)
_, _, done, info = self.step(actions)
done = False
for action in actions:
_, _, done, info = self.step(action)
if done:
raise ValueError(
f"Environment terminated with error: `{info.get('error_details')}`"
Expand Down
6 changes: 3 additions & 3 deletions compiler_gym/envs/llvm/llvm_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from compiler_gym.datasets import Benchmark
from compiler_gym.spaces.reward import Reward
from compiler_gym.util.gym_type_hints import ObservationType, RewardType
from compiler_gym.util.gym_type_hints import ActionType, ObservationType, RewardType
from compiler_gym.views.observation import ObservationView


Expand Down Expand Up @@ -44,7 +44,7 @@ def reset(self, benchmark: Benchmark, observation_view: ObservationView) -> None

def update(
self,
actions: List[int],
actions: List[ActionType],
observations: List[ObservationType],
observation_view: ObservationView,
) -> RewardType:
Expand Down Expand Up @@ -81,7 +81,7 @@ def reset(self, benchmark: str, observation_view: ObservationView) -> None:

def update(
self,
actions: List[int],
actions: List[ActionType],
observations: List[ObservationType],
observation_view: ObservationView,
) -> RewardType:
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/random_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)


@deprecated(version="0.2.1", reason="Use env.step(actions) instead")
@deprecated(version="0.2.1", reason="Use env.step(action) instead")
def replay_actions(env: CompilerEnv, action_names: List[str], outdir: Path):
    """Deprecated forwarding wrapper around :func:`replay_actions_`.

    Passes all arguments through unchanged; kept only for backwards
    compatibility and deprecated since v0.2.1.
    """
    return replay_actions_(env, action_names, outdir)

Expand Down
Loading