Skip to content

Commit

Permalink
Fix implementation of inheritance from ObservationWrapper and RewardWr…
Browse files Browse the repository at this point in the history
…apper
  • Loading branch information
sogartar committed Mar 28, 2022
1 parent 165b385 commit 3b9518a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
4 changes: 2 additions & 2 deletions examples/llvm_rl/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self.max = max
self.leakiness_factor = leakiness_factor

def reward(self, reward: float) -> float:
def convert_reward(self, reward: float) -> float:
if reward > self.max:
return self.max + (reward - self.max) * self.leakiness_factor
elif reward < self.min:
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self, env: CompilerEnv):
dtype=np.float32,
)

def observation(self, observation):
def convert_observation(self, observation):
if observation[self.TotalInsts_index] <= 0:
return np.zeros(observation.shape, dtype=np.float32)
return np.clip(
Expand Down
19 changes: 8 additions & 11 deletions tests/wrappers/core_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,23 @@

from compiler_gym.datasets import Datasets
from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.wrappers import (
ActionWrapper,
CompilerEnvWrapper,
ObservationWrapper,
RewardWrapper,
)
from compiler_gym.wrappers import ActionWrapper, CompilerEnvWrapper
from compiler_gym.wrappers import ObservationWrapper as CoreObservationWrapper
from compiler_gym.wrappers import RewardWrapper as CoreRewardWrapper
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]


class ObservationDummyWrapper(ObservationWrapper):
class ObservationWrapper(CoreObservationWrapper):
def __init__(self, env):
super().__init__(env)

def convert_observation(self, observation):
return observation


class RewardDummyWrapper(RewardWrapper):
class RewardWrapper(CoreRewardWrapper):
def __init__(self, env):
super().__init__(env)

Expand All @@ -40,8 +37,8 @@ def convert_reward(self, reward):
params=[
ActionWrapper,
CompilerEnvWrapper,
ObservationDummyWrapper,
RewardDummyWrapper,
ObservationWrapper,
RewardWrapper,
],
)
def wrapper_type(request):
Expand Down Expand Up @@ -285,7 +282,7 @@ def convert_observation(self, observation):

def test_wrapped_observation_missing_definition(env: LlvmEnv):
with pytest.raises(TypeError):
env = ObservationWrapper(env)
env = CoreObservationWrapper(env)


def test_wrapped_reward(env: LlvmEnv):
Expand Down

0 comments on commit 3b9518a

Please sign in to comment.