Skip to content

Commit

Permalink
Merge pull request #683 from ChrisCummins/feature/counter-wrapper
Browse files Browse the repository at this point in the history
[wrappers] Add a Counter wrapper class.
  • Loading branch information
ChrisCummins authored May 17, 2022
2 parents 596c242 + 1ba722c commit f30ebba
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler_gym/wrappers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ py_library(
"__init__.py",
"commandline.py",
"core.py",
"counter.py",
"datasets.py",
"fork.py",
"llvm.py",
Expand Down
1 change: 1 addition & 0 deletions compiler_gym/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ set(WRAPPERS_SRCS
"__init__.py"
"commandline.py"
"core.py"
"counter.py"
"datasets.py"
"fork.py"
"time_limit.py"
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ObservationWrapper,
RewardWrapper,
)
from compiler_gym.wrappers.counter import Counter
from compiler_gym.wrappers.datasets import (
CycleOverBenchmarks,
CycleOverBenchmarksIterator,
Expand All @@ -61,6 +62,7 @@
"CommandlineWithTerminalAction",
"CompilerEnvWrapper",
"ConstrainedCommandline",
"Counter",
"CycleOverBenchmarks",
"CycleOverBenchmarksIterator",
"ForkOnStep",
Expand Down
59 changes: 59 additions & 0 deletions compiler_gym/wrappers/counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements a wrapper that counts calls to operations.
"""
from typing import Dict

from compiler_gym.envs import CompilerEnv
from compiler_gym.wrappers import CompilerEnvWrapper


class Counter(CompilerEnvWrapper):
"""A wrapper that counts the number of calls to its operations.
The counters are _not_ reset by :meth:`env.reset()
<compiler_gym.envs.CompilerEnv.reset>`.
Example usage:
>>> env = Counter(compiler_gym.make("llvm-v0"))
>>> env.counters
{"close": 0, "reset": 0, "step": 0, "fork": 0}
>>> env.step(0)
{"close": 0, "reset": 0, "step": 1, "fork": 0}
:ivar counters: A dictionary of counters for different operation types.
:vartype counters: Dict[str, int]
"""

def __init__(self, env: CompilerEnv):
"""Constructor.
:param env: The environment to wrap.
"""
super().__init__(env)
self.counters: Dict[str, int] = {
"close": 0,
"reset": 0,
"step": 0,
"fork": 0,
}

def close(self) -> None:
self.counters["close"] += 1
self.env.close()

def reset(self, *args, **kwargs):
self.counters["reset"] += 1
return self.env.reset(*args, **kwargs)

def step(self, *args, **kwargs):
self.counters["step"] += 1
return self.env.step(*args, **kwargs)

def fork(self):
self.counters["fork"] += 1
return self.env.fork()
10 changes: 10 additions & 0 deletions tests/wrappers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ py_test(
],
)

py_test(
name = "counter_test",
srcs = ["counter_test.py"],
deps = [
"//compiler_gym/wrappers",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
)

py_test(
name = "datasets_wrappers_test",
timeout = "short",
Expand Down
11 changes: 11 additions & 0 deletions tests/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ cg_py_test(
tests::test_main
)

cg_py_test(
NAME counter_test
SRCS "counter_test.py"
DEPS
compiler_gym::envs::llvm::llvm
compiler_gym::errors::errors
compiler_gym::wrappers::wrappers
tests::test_main
tests::pytest_plugins::llvm
)

cg_py_test(
NAME
datasets_wrappers_test
Expand Down
65 changes: 65 additions & 0 deletions tests/wrappers/counter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/wrappers."""
from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.wrappers import Counter
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]


def test_Counter_reset(env: LlvmEnv):
with Counter(env) as env:
env.reset()
assert env.counters == {
"close": 0,
"fork": 0,
"reset": 1,
"step": 0,
}

env.reset()
assert env.counters == {
"close": 0,
"fork": 0,
"reset": 2,
"step": 0,
}


def test_Counter_step(env: LlvmEnv):
with Counter(env) as env:
env.reset()
env.step(0)
assert env.counters == {
"close": 0,
"fork": 0,
"reset": 1,
"step": 1,
}


def test_Counter_double_close(env: LlvmEnv):
with Counter(env) as env:
env.close()
env.close()
assert env.counters == {
"close": 2,
"fork": 0,
"reset": 0,
"step": 0,
}

# Implicit close in `with` statement.
assert env.counters == {
"close": 3,
"fork": 0,
"reset": 0,
"step": 0,
}


if __name__ == "__main__":
main()

0 comments on commit f30ebba

Please sign in to comment.