Skip to content

Commit

Permalink
Merge pull request #463 from ChrisCummins/issue-461
Browse files Browse the repository at this point in the history
Make derived observation spaces compatible with constructor.
  • Loading branch information
ChrisCummins authored Oct 12, 2021
2 parents 4d8e4dd + 44d8602 commit 178764f
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 107 deletions.
10 changes: 10 additions & 0 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
observation_space: Optional[Union[str, ObservationSpaceSpec]] = None,
reward_space: Optional[Union[str, Reward]] = None,
action_space: Optional[str] = None,
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
logger: Optional[logging.Logger] = None,
Expand Down Expand Up @@ -205,6 +206,10 @@ def __init__(
:param action_space: The name of the action space to use. If not
specified, the default action space for this compiler is used.
:param derived_observation_spaces: An optional list of arguments to be
passed to :meth:`env.observation.add_derived_space()
<compiler_gym.views.observation.Observation.add_derived_space>`.
:param connection_settings: The settings used to establish a connection
with the remote service.
Expand Down Expand Up @@ -299,6 +304,11 @@ def __init__(
)
self.reward = self._reward_view_type(rewards, self.observation)

# Register any derived observation spaces now so that the observation
# space can be set below.
for derived_observation_space in derived_observation_spaces or []:
self.observation.add_derived_space_internal(**derived_observation_space)

# Lazily evaluated version strings.
self._versions: Optional[GetVersionReply] = None

Expand Down
204 changes: 103 additions & 101 deletions compiler_gym/envs/llvm/llvm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
# First perform a one-time download of LLVM binaries that are needed by
# the LLVM service and are not included by the pip-installed package.
download_llvm_files()
self.inst2vec = _INST2VEC_ENCODER
super().__init__(
*args,
**kwargs,
Expand Down Expand Up @@ -150,15 +151,115 @@ def __init__(
platform_dependent=True,
),
],
derived_observation_spaces=[
{
"id": "Inst2vecPreprocessedText",
"base_id": "Ir",
"space": Sequence(
name="Inst2vecPreprocessedText", size_range=(0, None), dtype=str
),
"translate": self.inst2vec.preprocess,
"default_value": "",
},
{
"id": "Inst2vecEmbeddingIndices",
"base_id": "Ir",
"space": Sequence(
name="Inst2vecEmbeddingIndices",
size_range=(0, None),
dtype=np.int32,
),
"translate": lambda base_observation: self.inst2vec.encode(
self.inst2vec.preprocess(base_observation)
),
"default_value": np.array([self.inst2vec.vocab["!UNK"]]),
},
{
"id": "Inst2vec",
"base_id": "Ir",
"space": Sequence(
name="Inst2vec", size_range=(0, None), dtype=np.ndarray
),
"translate": lambda base_observation: self.inst2vec.embed(
self.inst2vec.encode(self.inst2vec.preprocess(base_observation))
),
"default_value": np.vstack(
[self.inst2vec.embeddings[self.inst2vec.vocab["!UNK"]]]
),
},
{
"id": "InstCountDict",
"base_id": "InstCount",
"space": DictSpace(
{
f"{name}Count": Scalar(
name=f"{name}Count", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES
},
name="InstCountDict",
),
"translate": lambda base_observation: {
f"{name}Count": val
for name, val in zip(INST_COUNT_FEATURE_NAMES, base_observation)
},
},
{
"id": "InstCountNorm",
"base_id": "InstCount",
"space": Box(
name="InstCountNorm",
low=0,
high=1,
shape=(len(INST_COUNT_FEATURE_NAMES) - 1,),
dtype=np.float32,
),
"translate": lambda base_observation: (
base_observation[1:] / max(base_observation[0], 1)
).astype(np.float32),
},
{
"id": "InstCountNormDict",
"base_id": "InstCountNorm",
"space": DictSpace(
{
f"{name}Density": Scalar(
name=f"{name}Density", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES[1:]
},
name="InstCountNormDict",
),
"translate": lambda base_observation: {
f"{name}Density": val
for name, val in zip(
INST_COUNT_FEATURE_NAMES[1:], base_observation
)
},
},
{
"id": "AutophaseDict",
"base_id": "Autophase",
"space": DictSpace(
{
name: Scalar(name=name, min=0, max=None, dtype=int)
for name in AUTOPHASE_FEATURE_NAMES
},
name="AutophaseDict",
),
"translate": lambda base_observation: {
name: val
for name, val in zip(AUTOPHASE_FEATURE_NAMES, base_observation)
},
},
],
)

# Mutable runtime configuration options that must be set on every call
# to reset.
self._runtimes_per_observation_count: Optional[int] = None
self._runtimes_warmup_per_observation_count: Optional[int] = None

self.inst2vec = _INST2VEC_ENCODER

cpu_info_spaces = [
Sequence(name="name", size_range=(0, None), dtype=str),
Scalar(name="cores_count", min=None, max=None, dtype=int),
Expand All @@ -178,105 +279,6 @@ def __init__(
name="CpuInfo",
)

self.observation.add_derived_space(
id="Inst2vecPreprocessedText",
base_id="Ir",
space=Sequence(
name="Inst2vecPreprocessedText", size_range=(0, None), dtype=str
),
translate=self.inst2vec.preprocess,
default_value="",
)
self.observation.add_derived_space(
id="Inst2vecEmbeddingIndices",
base_id="Ir",
space=Sequence(
name="Inst2vecEmbeddingIndices", size_range=(0, None), dtype=np.int32
),
translate=lambda base_observation: self.inst2vec.encode(
self.inst2vec.preprocess(base_observation)
),
default_value=np.array([self.inst2vec.vocab["!UNK"]]),
)
self.observation.add_derived_space(
id="Inst2vec",
base_id="Ir",
space=Sequence(name="Inst2vec", size_range=(0, None), dtype=np.ndarray),
translate=lambda base_observation: self.inst2vec.embed(
self.inst2vec.encode(self.inst2vec.preprocess(base_observation))
),
default_value=np.vstack(
[self.inst2vec.embeddings[self.inst2vec.vocab["!UNK"]]]
),
)

self.observation.add_derived_space(
id="InstCountDict",
base_id="InstCount",
space=DictSpace(
{
f"{name}Count": Scalar(
name=f"{name}Count", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES
},
name="InstCountDict",
),
translate=lambda base_observation: {
f"{name}Count": val
for name, val in zip(INST_COUNT_FEATURE_NAMES, base_observation)
},
)

self.observation.add_derived_space(
id="InstCountNorm",
base_id="InstCount",
space=Box(
name="InstCountNorm",
low=0,
high=1,
shape=(len(INST_COUNT_FEATURE_NAMES) - 1,),
dtype=np.float32,
),
translate=lambda base_observation: (
base_observation[1:] / max(base_observation[0], 1)
).astype(np.float32),
)

self.observation.add_derived_space(
id="InstCountNormDict",
base_id="InstCountNorm",
space=DictSpace(
{
f"{name}Density": Scalar(
name=f"{name}Density", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES[1:]
},
name="InstCountNormDict",
),
translate=lambda base_observation: {
f"{name}Density": val
for name, val in zip(INST_COUNT_FEATURE_NAMES[1:], base_observation)
},
)

self.observation.add_derived_space(
id="AutophaseDict",
base_id="Autophase",
space=DictSpace(
{
name: Scalar(name=name, min=0, max=None, dtype=int)
for name in AUTOPHASE_FEATURE_NAMES
},
name="AutophaseDict",
),
translate=lambda base_observation: {
name: val
for name, val in zip(AUTOPHASE_FEATURE_NAMES, base_observation)
},
)

def reset(self, *args, **kwargs):
try:
observation = super().reset(*args, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions compiler_gym/envs/llvm/make_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Usage: make_specs.py <service_binary> <output_path>.
"""
import signal

# TODO: As we add support for more compilers we could generalize this script
# to work with other compiler services rather than hardcoding to LLVM.
import sys
Expand All @@ -19,11 +21,24 @@
) as f:
_FLAG_DESCRIPTIONS = [ln.rstrip() for ln in f.readlines()]

# The maximum number of seconds to wait before timing out.
TIMEOUT_SECONDS = 300


def timeout_handler(signum, frame):
del signum # unused
del frame # unused
print(f"error: Timeout reached after {TIMEOUT_SECONDS:,d} seconds", file=sys.stderr)
sys.exit(1)


def main(argv):
assert len(argv) == 3, "Usage: make_specs.py <service_binary> <output_path>"
service_path, output_path = argv[1:]

signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(TIMEOUT_SECONDS)

with LlvmEnv(Path(service_path)) as env:
with open(output_path, "w") as f:
print("from enum import Enum", file=f)
Expand Down
22 changes: 22 additions & 0 deletions compiler_gym/views/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List

from deprecated.sphinx import deprecated

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
Expand Down Expand Up @@ -93,6 +95,13 @@ def _add_space(self, space: ObservationSpaceSpec):
# env.observation.FooBar().
setattr(self, space.id, lambda: self[space.id])

@deprecated(
version="0.2.1",
reason=(
"Use the derived_observation_spaces argument to CompilerEnv constructor. "
"See <https://github.com/facebookresearch/CompilerGym/issues/461>."
),
)
def add_derived_space(
self,
id: str,
Expand Down Expand Up @@ -126,5 +135,18 @@ def add_derived_space(
base_space = self.spaces[base_id]
self._add_space(base_space.make_derived_space(id=id, **kwargs))

# NOTE(github.com/facebookresearch/CompilerGym/issues/461): This method will
# be renamed to add_derived_space() once the current method with that name
# is removed.
def add_derived_space_internal(
self,
id: str,
base_id: str,
**kwargs,
) -> None:
"""Internal API for adding a new observation space."""
base_space = self.spaces[base_id]
self._add_space(base_space.make_derived_space(id=id, **kwargs))

def __repr__(self):
return f"ObservationView[{', '.join(sorted(self.spaces.keys()))}]"
30 changes: 24 additions & 6 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from typing import Any, Dict, List

import gym
import networkx as nx
import numpy as np
import pytest
Expand Down Expand Up @@ -1379,12 +1380,15 @@ def test_is_buildable_observation_space_not_buildable(env: LlvmEnv):

def test_add_derived_space(env: LlvmEnv):
env.reset()
env.observation.add_derived_space(
id="IrLen",
base_id="Ir",
space=Box(name="IrLen", low=0, high=float("inf"), shape=(1,), dtype=int),
translate=lambda base: [15],
)
with pytest.deprecated_call(
match="Use the derived_observation_spaces argument to CompilerEnv constructor."
):
env.observation.add_derived_space(
id="IrLen",
base_id="Ir",
space=Box(name="IrLen", low=0, high=float("inf"), shape=(1,), dtype=int),
translate=lambda base: [15],
)

value = env.observation["IrLen"]
assert isinstance(value, list)
Expand All @@ -1396,5 +1400,19 @@ def test_add_derived_space(env: LlvmEnv):
assert value == [15]


def test_derived_space_constructor():
"""Test that derived observation space can be specified at construction
time.
"""
with gym.make("llvm-v0") as env:
env.observation_space = "AutophaseDict"
a = env.reset()

with gym.make("llvm-v0", observation_space="AutophaseDict") as env:
b = env.reset()

assert a == b


if __name__ == "__main__":
main()

0 comments on commit 178764f

Please sign in to comment.