Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved error reporting from ObservationView.__getitem__(). #380

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions compiler_gym/views/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
ActionType,
Expand Down Expand Up @@ -58,14 +59,29 @@ def __getitem__(self, observation_space: str) -> ObservationType:

:raises SessionNotFound: If :meth:`env.reset()
<compiler_gym.envs.CompilerEnv.reset>` has not been called.

:raises ServiceError: If the backend service fails to compute the
observation, or reports that a terminal state has been reached.
"""
observation_space: ObservationSpaceSpec = self.spaces[observation_space]
observations, _, _, _ = self._raw_step(
observations, _, done, info = self._raw_step(
actions=[], observations=[observation_space], rewards=[]
)
assert (
len(observations) == 1
), f"Expected 1 observation. Received: {len(observations)}"

if done:
# Computing an observation should never cause a terminal state since
# no action has been applied.
msg = f"Failed to compute observation '{observation_space.id}'"
if info.get("error_details"):
msg += f": {info['error_details']}"
raise ServiceError(msg)

if len(observations) != 1:
raise ServiceError(
f"Expected 1 '{observation_space.id}' observation "
f"but the service returned {len(observations)}"
)

return observations[0]

def _add_space(self, space: ObservationSpaceSpec):
Expand Down
73 changes: 69 additions & 4 deletions tests/views/observation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import (
ObservationSpace,
ScalarLimit,
Expand All @@ -16,8 +17,8 @@
from tests.test_main import main


class MockGetObservation:
"""Mock for the get_observation callack of ObservationView."""
class MockRawStep:
"""Mock for the raw_step callack of ObservationView."""

def __init__(self, ret=None):
self.called_observation_spaces = []
Expand All @@ -35,7 +36,7 @@ def __call__(self, actions, observations, rewards):

def test_empty_space():
with pytest.raises(ValueError) as ctx:
ObservationView(MockGetObservation(), [])
ObservationView(MockRawStep(), [])
assert str(ctx.value) == "No observation spaces"


Expand Down Expand Up @@ -73,7 +74,7 @@ def test_observed_value_types():
),
),
]
mock = MockGetObservation(
mock = MockRawStep(
ret=[
"Hello, IR",
[1.0, 2.0],
Expand Down Expand Up @@ -122,5 +123,69 @@ def test_observed_value_types():
assert mock.called_observation_spaces == ["ir", "dfeat", "features", "binary"]


def test_observation_when_raw_step_returns_incorrect_no_of_observations():
"""Test that a ServiceError is propagated when raw_step() returns unexpected
number of observations."""

def make_failing_raw_step(n: int):
def failing_raw_step(*args, **kwargs):
"""A callback that returns done=True."""
del args # Unused
del kwargs # Unused
return ["ir"] * n, None, False, {}

return failing_raw_step

spaces = [
ObservationSpace(
name="ir",
string_size_range=ScalarRange(min=ScalarLimit(value=0)),
)
]

observation = ObservationView(make_failing_raw_step(0), spaces)
with pytest.raises(
ServiceError, match=r"^Expected 1 'ir' observation but the service returned 0$"
):
observation["ir"]

observation = ObservationView(make_failing_raw_step(3), spaces)
with pytest.raises(
ServiceError, match=r"^Expected 1 'ir' observation but the service returned 3$"
):
observation["ir"]


def test_observation_when_raw_step_returns_done():
"""Test that a SessionNotFoundError from the raw_step() callback propagates as a """

def make_failing_raw_step(error_msg=None):
def failing_raw_step(*args, **kwargs):
"""A callback that returns done=True."""
info = {}
if error_msg:
info["error_details"] = error_msg
return [], None, True, info

return failing_raw_step

spaces = [
ObservationSpace(
name="ir",
string_size_range=ScalarRange(min=ScalarLimit(value=0)),
)
]

observation = ObservationView(make_failing_raw_step(), spaces)
with pytest.raises(ServiceError, match=r"^Failed to compute observation 'ir'$"):
observation["ir"] # pylint: disable=pointless-statement

observation = ObservationView(make_failing_raw_step("Oh no!"), spaces)
with pytest.raises(
ServiceError, match=r"^Failed to compute observation 'ir': Oh no!$"
):
observation["ir"] # pylint: disable=pointless-statement


if __name__ == "__main__":
main()