Skip to content

Commit

Permalink
Check that environment is not done on observation request.
Browse files Browse the repository at this point in the history
This patch improves the error reporting when computing an observation
fails. First, if the service produces an unexpected number of
observations, a ServiceError is raised, rather than the previous
assertion. Second, if the environment reports that it has reached a
terminal state, a ServiceError is raised, containing the error details
produced by the environment.
  • Loading branch information
ChrisCummins committed Sep 8, 2021
1 parent 32c8294 commit cf45a24
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 4 deletions.
24 changes: 20 additions & 4 deletions compiler_gym/views/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
ActionType,
Expand Down Expand Up @@ -58,14 +59,29 @@ def __getitem__(self, observation_space: str) -> ObservationType:
:raises SessionNotFound: If :meth:`env.reset()
<compiler_gym.envs.CompilerEnv.reset>` has not been called.
:raises ServiceError: If the backend service fails to compute the
observation, or reports that a terminal state has been reached.
"""
observation_space: ObservationSpaceSpec = self.spaces[observation_space]
observations, _, _, _ = self._raw_step(
observations, _, done, info = self._raw_step(
actions=[], observations=[observation_space], rewards=[]
)
assert (
len(observations) == 1
), f"Expected 1 observation. Received: {len(observations)}"

if done:
# Computing an observation should never cause a terminal state since
# no action has been applied.
msg = f"Failed to compute observation '{observation_space.id}'"
if info.get("error_details"):
msg += f": {info['error_details']}"
raise ServiceError(msg)

if len(observations) != 1:
raise ServiceError(
f"Expected 1 '{observation_space.id}' observation "
f"but the service returned {len(observations)}"
)

return observations[0]

def _add_space(self, space: ObservationSpaceSpec):
Expand Down
65 changes: 65 additions & 0 deletions tests/views/observation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import (
ObservationSpace,
ScalarLimit,
Expand Down Expand Up @@ -122,5 +123,69 @@ def test_observed_value_types():
assert mock.called_observation_spaces == ["ir", "dfeat", "features", "binary"]


def test_observation_when_raw_step_returns_incorrect_no_of_observations():
"""Test that a ServiceError is propagated when raw_step() returns unexpected
number of observations."""

def make_failing_raw_step(n: int):
def failing_raw_step(*args, **kwargs):
"""A callback that returns done=True."""
del args # Unused
del kwargs # Unused
return ["ir"] * n, None, False, {}

return failing_raw_step

spaces = [
ObservationSpace(
name="ir",
string_size_range=ScalarRange(min=ScalarLimit(value=0)),
)
]

observation = ObservationView(make_failing_raw_step(0), spaces)
with pytest.raises(
ServiceError, match=r"^Expected 1 'ir' observation but the service returned 0$"
):
observation["ir"]

observation = ObservationView(make_failing_raw_step(3), spaces)
with pytest.raises(
ServiceError, match=r"^Expected 1 'ir' observation but the service returned 3$"
):
observation["ir"]


def test_observation_when_raw_step_returns_done():
"""Test that a SessionNotFoundError from the raw_step() callback propagates as a """

def make_failing_raw_step(error_msg=None):
def failing_raw_step(*args, **kwargs):
"""A callback that returns done=True."""
info = {}
if error_msg:
info["error_details"] = error_msg
return [], None, True, info

return failing_raw_step

spaces = [
ObservationSpace(
name="ir",
string_size_range=ScalarRange(min=ScalarLimit(value=0)),
)
]

observation = ObservationView(make_failing_raw_step(), spaces)
with pytest.raises(ServiceError, match=r"^Failed to compute observation 'ir'$"):
observation["ir"] # pylint: disable=pointless-statement

observation = ObservationView(make_failing_raw_step("Oh no!"), spaces)
with pytest.raises(
ServiceError, match=r"^Failed to compute observation 'ir': Oh no!$"
):
observation["ir"] # pylint: disable=pointless-statement


if __name__ == "__main__":
main()

0 comments on commit cf45a24

Please sign in to comment.