Skip to content

Commit

Permalink
Merge pull request #380 from ChrisCummins/observation-error-message
Browse files Browse the repository at this point in the history
Improved error reporting from ObservationView.__getitem__().
  • Loading branch information
ChrisCummins authored Sep 8, 2021
2 parents c824115 + cf45a24 commit 5354ddd
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 8 deletions.
24 changes: 20 additions & 4 deletions compiler_gym/views/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
ActionType,
Expand Down Expand Up @@ -58,14 +59,29 @@ def __getitem__(self, observation_space: str) -> ObservationType:
:raises SessionNotFound: If :meth:`env.reset()
<compiler_gym.envs.CompilerEnv.reset>` has not been called.
:raises ServiceError: If the backend service fails to compute the
observation, or reports that a terminal state has been reached.
"""
observation_space: ObservationSpaceSpec = self.spaces[observation_space]
observations, _, _, _ = self._raw_step(
observations, _, done, info = self._raw_step(
actions=[], observations=[observation_space], rewards=[]
)
assert (
len(observations) == 1
), f"Expected 1 observation. Received: {len(observations)}"

if done:
# Computing an observation should never cause a terminal state since
# no action has been applied.
msg = f"Failed to compute observation '{observation_space.id}'"
if info.get("error_details"):
msg += f": {info['error_details']}"
raise ServiceError(msg)

if len(observations) != 1:
raise ServiceError(
f"Expected 1 '{observation_space.id}' observation "
f"but the service returned {len(observations)}"
)

return observations[0]

def _add_space(self, space: ObservationSpaceSpec):
Expand Down
73 changes: 69 additions & 4 deletions tests/views/observation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import (
ObservationSpace,
ScalarLimit,
Expand All @@ -16,8 +17,8 @@
from tests.test_main import main


class MockGetObservation:
"""Mock for the get_observation callack of ObservationView."""
class MockRawStep:
"""Mock for the raw_step callack of ObservationView."""

def __init__(self, ret=None):
self.called_observation_spaces = []
Expand All @@ -35,7 +36,7 @@ def __call__(self, actions, observations, rewards):

def test_empty_space():
with pytest.raises(ValueError) as ctx:
ObservationView(MockGetObservation(), [])
ObservationView(MockRawStep(), [])
assert str(ctx.value) == "No observation spaces"


Expand Down Expand Up @@ -73,7 +74,7 @@ def test_observed_value_types():
),
),
]
mock = MockGetObservation(
mock = MockRawStep(
ret=[
"Hello, IR",
[1.0, 2.0],
Expand Down Expand Up @@ -122,5 +123,69 @@ def test_observed_value_types():
assert mock.called_observation_spaces == ["ir", "dfeat", "features", "binary"]


def test_observation_when_raw_step_returns_incorrect_no_of_observations():
"""Test that a ServiceError is propagated when raw_step() returns unexpected
number of observations."""

def make_failing_raw_step(n: int):
def failing_raw_step(*args, **kwargs):
"""A callback that returns done=True."""
del args # Unused
del kwargs # Unused
return ["ir"] * n, None, False, {}

return failing_raw_step

spaces = [
ObservationSpace(
name="ir",
string_size_range=ScalarRange(min=ScalarLimit(value=0)),
)
]

observation = ObservationView(make_failing_raw_step(0), spaces)
with pytest.raises(
ServiceError, match=r"^Expected 1 'ir' observation but the service returned 0$"
):
observation["ir"]

observation = ObservationView(make_failing_raw_step(3), spaces)
with pytest.raises(
ServiceError, match=r"^Expected 1 'ir' observation but the service returned 3$"
):
observation["ir"]


def test_observation_when_raw_step_returns_done():
"""Test that a SessionNotFoundError from the raw_step() callback propagates as a """

def make_failing_raw_step(error_msg=None):
def failing_raw_step(*args, **kwargs):
"""A callback that returns done=True."""
info = {}
if error_msg:
info["error_details"] = error_msg
return [], None, True, info

return failing_raw_step

spaces = [
ObservationSpace(
name="ir",
string_size_range=ScalarRange(min=ScalarLimit(value=0)),
)
]

observation = ObservationView(make_failing_raw_step(), spaces)
with pytest.raises(ServiceError, match=r"^Failed to compute observation 'ir'$"):
observation["ir"] # pylint: disable=pointless-statement

observation = ObservationView(make_failing_raw_step("Oh no!"), spaces)
with pytest.raises(
ServiceError, match=r"^Failed to compute observation 'ir': Oh no!$"
):
observation["ir"] # pylint: disable=pointless-statement


if __name__ == "__main__":
main()

0 comments on commit 5354ddd

Please sign in to comment.