Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make derived observation spaces compatible with constructor. #463

Merged
merged 3 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
observation_space: Optional[Union[str, ObservationSpaceSpec]] = None,
reward_space: Optional[Union[str, Reward]] = None,
action_space: Optional[str] = None,
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
logger: Optional[logging.Logger] = None,
Expand Down Expand Up @@ -205,6 +206,10 @@ def __init__(
:param action_space: The name of the action space to use. If not
specified, the default action space for this compiler is used.

:param derived_observation_spaces: An optional list of arguments to be
passed to :meth:`env.observation.add_derived_space()
<compiler_gym.views.observation.Observation.add_derived_space>`.

:param connection_settings: The settings used to establish a connection
with the remote service.

Expand Down Expand Up @@ -299,6 +304,11 @@ def __init__(
)
self.reward = self._reward_view_type(rewards, self.observation)

# Register any derived observation spaces now so that the observation
# space can be set below.
for derived_observation_space in derived_observation_spaces or []:
self.observation.add_derived_space_internal(**derived_observation_space)

# Lazily evaluated version strings.
self._versions: Optional[GetVersionReply] = None

Expand Down
204 changes: 103 additions & 101 deletions compiler_gym/envs/llvm/llvm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
# First perform a one-time download of LLVM binaries that are needed by
# the LLVM service and are not included by the pip-installed package.
download_llvm_files()
self.inst2vec = _INST2VEC_ENCODER
super().__init__(
*args,
**kwargs,
Expand Down Expand Up @@ -150,15 +151,115 @@ def __init__(
platform_dependent=True,
),
],
derived_observation_spaces=[
{
"id": "Inst2vecPreprocessedText",
"base_id": "Ir",
"space": Sequence(
name="Inst2vecPreprocessedText", size_range=(0, None), dtype=str
),
"translate": self.inst2vec.preprocess,
"default_value": "",
},
{
"id": "Inst2vecEmbeddingIndices",
"base_id": "Ir",
"space": Sequence(
name="Inst2vecEmbeddingIndices",
size_range=(0, None),
dtype=np.int32,
),
"translate": lambda base_observation: self.inst2vec.encode(
self.inst2vec.preprocess(base_observation)
),
"default_value": np.array([self.inst2vec.vocab["!UNK"]]),
},
{
"id": "Inst2vec",
"base_id": "Ir",
"space": Sequence(
name="Inst2vec", size_range=(0, None), dtype=np.ndarray
),
"translate": lambda base_observation: self.inst2vec.embed(
self.inst2vec.encode(self.inst2vec.preprocess(base_observation))
),
"default_value": np.vstack(
[self.inst2vec.embeddings[self.inst2vec.vocab["!UNK"]]]
),
},
{
"id": "InstCountDict",
"base_id": "InstCount",
"space": DictSpace(
{
f"{name}Count": Scalar(
name=f"{name}Count", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES
},
name="InstCountDict",
),
"translate": lambda base_observation: {
f"{name}Count": val
for name, val in zip(INST_COUNT_FEATURE_NAMES, base_observation)
},
},
{
"id": "InstCountNorm",
"base_id": "InstCount",
"space": Box(
name="InstCountNorm",
low=0,
high=1,
shape=(len(INST_COUNT_FEATURE_NAMES) - 1,),
dtype=np.float32,
),
"translate": lambda base_observation: (
base_observation[1:] / max(base_observation[0], 1)
).astype(np.float32),
},
{
"id": "InstCountNormDict",
"base_id": "InstCountNorm",
"space": DictSpace(
{
f"{name}Density": Scalar(
name=f"{name}Density", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES[1:]
},
name="InstCountNormDict",
),
"translate": lambda base_observation: {
f"{name}Density": val
for name, val in zip(
INST_COUNT_FEATURE_NAMES[1:], base_observation
)
},
},
{
"id": "AutophaseDict",
"base_id": "Autophase",
"space": DictSpace(
{
name: Scalar(name=name, min=0, max=None, dtype=int)
for name in AUTOPHASE_FEATURE_NAMES
},
name="AutophaseDict",
),
"translate": lambda base_observation: {
name: val
for name, val in zip(AUTOPHASE_FEATURE_NAMES, base_observation)
},
},
],
)

# Mutable runtime configuration options that must be set on every call
# to reset.
self._runtimes_per_observation_count: Optional[int] = None
self._runtimes_warmup_per_observation_count: Optional[int] = None

self.inst2vec = _INST2VEC_ENCODER

cpu_info_spaces = [
Sequence(name="name", size_range=(0, None), dtype=str),
Scalar(name="cores_count", min=None, max=None, dtype=int),
Expand All @@ -178,105 +279,6 @@ def __init__(
name="CpuInfo",
)

self.observation.add_derived_space(
id="Inst2vecPreprocessedText",
base_id="Ir",
space=Sequence(
name="Inst2vecPreprocessedText", size_range=(0, None), dtype=str
),
translate=self.inst2vec.preprocess,
default_value="",
)
self.observation.add_derived_space(
id="Inst2vecEmbeddingIndices",
base_id="Ir",
space=Sequence(
name="Inst2vecEmbeddingIndices", size_range=(0, None), dtype=np.int32
),
translate=lambda base_observation: self.inst2vec.encode(
self.inst2vec.preprocess(base_observation)
),
default_value=np.array([self.inst2vec.vocab["!UNK"]]),
)
self.observation.add_derived_space(
id="Inst2vec",
base_id="Ir",
space=Sequence(name="Inst2vec", size_range=(0, None), dtype=np.ndarray),
translate=lambda base_observation: self.inst2vec.embed(
self.inst2vec.encode(self.inst2vec.preprocess(base_observation))
),
default_value=np.vstack(
[self.inst2vec.embeddings[self.inst2vec.vocab["!UNK"]]]
),
)

self.observation.add_derived_space(
id="InstCountDict",
base_id="InstCount",
space=DictSpace(
{
f"{name}Count": Scalar(
name=f"{name}Count", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES
},
name="InstCountDict",
),
translate=lambda base_observation: {
f"{name}Count": val
for name, val in zip(INST_COUNT_FEATURE_NAMES, base_observation)
},
)

self.observation.add_derived_space(
id="InstCountNorm",
base_id="InstCount",
space=Box(
name="InstCountNorm",
low=0,
high=1,
shape=(len(INST_COUNT_FEATURE_NAMES) - 1,),
dtype=np.float32,
),
translate=lambda base_observation: (
base_observation[1:] / max(base_observation[0], 1)
).astype(np.float32),
)

self.observation.add_derived_space(
id="InstCountNormDict",
base_id="InstCountNorm",
space=DictSpace(
{
f"{name}Density": Scalar(
name=f"{name}Density", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES[1:]
},
name="InstCountNormDict",
),
translate=lambda base_observation: {
f"{name}Density": val
for name, val in zip(INST_COUNT_FEATURE_NAMES[1:], base_observation)
},
)

self.observation.add_derived_space(
id="AutophaseDict",
base_id="Autophase",
space=DictSpace(
{
name: Scalar(name=name, min=0, max=None, dtype=int)
for name in AUTOPHASE_FEATURE_NAMES
},
name="AutophaseDict",
),
translate=lambda base_observation: {
name: val
for name, val in zip(AUTOPHASE_FEATURE_NAMES, base_observation)
},
)

def reset(self, *args, **kwargs):
try:
observation = super().reset(*args, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions compiler_gym/envs/llvm/make_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

Usage: make_specs.py <service_binary> <output_path>.
"""
import signal

# TODO: As we add support for more compilers we could generalize this script
# to work with other compiler services rather than hardcoding to LLVM.
import sys
Expand All @@ -19,11 +21,24 @@
) as f:
_FLAG_DESCRIPTIONS = [ln.rstrip() for ln in f.readlines()]

# The maximum number of seconds to wait before timing out.
TIMEOUT_SECONDS = 300


def timeout_handler(signum, frame):
del signum # unused
del frame # unused
print(f"error: Timeout reached after {TIMEOUT_SECONDS:,d} seconds", file=sys.stderr)
sys.exit(1)


def main(argv):
assert len(argv) == 3, "Usage: make_specs.py <service_binary> <output_path>"
service_path, output_path = argv[1:]

signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(TIMEOUT_SECONDS)

with LlvmEnv(Path(service_path)) as env:
with open(output_path, "w") as f:
print("from enum import Enum", file=f)
Expand Down
22 changes: 22 additions & 0 deletions compiler_gym/views/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List

from deprecated.sphinx import deprecated

from compiler_gym.service.connection import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
Expand Down Expand Up @@ -93,6 +95,13 @@ def _add_space(self, space: ObservationSpaceSpec):
# env.observation.FooBar().
setattr(self, space.id, lambda: self[space.id])

@deprecated(
version="0.2.1",
reason=(
"Use the derived_observation_spaces argument to CompilerEnv constructor. "
"See <https://github.com/facebookresearch/CompilerGym/issues/461>."
),
)
def add_derived_space(
self,
id: str,
Expand Down Expand Up @@ -126,5 +135,18 @@ def add_derived_space(
base_space = self.spaces[base_id]
self._add_space(base_space.make_derived_space(id=id, **kwargs))

# NOTE(github.com/facebookresearch/CompilerGym/issues/461): This method will
# be renamed to add_derived_space() once the current method with that name
# is removed.
def add_derived_space_internal(
self,
id: str,
base_id: str,
**kwargs,
) -> None:
"""Internal API for adding a new observation space."""
base_space = self.spaces[base_id]
self._add_space(base_space.make_derived_space(id=id, **kwargs))

def __repr__(self):
return f"ObservationView[{', '.join(sorted(self.spaces.keys()))}]"
30 changes: 24 additions & 6 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from typing import Any, Dict, List

import gym
import networkx as nx
import numpy as np
import pytest
Expand Down Expand Up @@ -1379,12 +1380,15 @@ def test_is_buildable_observation_space_not_buildable(env: LlvmEnv):

def test_add_derived_space(env: LlvmEnv):
env.reset()
env.observation.add_derived_space(
id="IrLen",
base_id="Ir",
space=Box(name="IrLen", low=0, high=float("inf"), shape=(1,), dtype=int),
translate=lambda base: [15],
)
with pytest.deprecated_call(
match="Use the derived_observation_spaces argument to CompilerEnv constructor."
):
env.observation.add_derived_space(
id="IrLen",
base_id="Ir",
space=Box(name="IrLen", low=0, high=float("inf"), shape=(1,), dtype=int),
translate=lambda base: [15],
)

value = env.observation["IrLen"]
assert isinstance(value, list)
Expand All @@ -1396,5 +1400,19 @@ def test_add_derived_space(env: LlvmEnv):
assert value == [15]


def test_derived_space_constructor():
"""Test that derived observation space can be specified at construction
time.
"""
with gym.make("llvm-v0") as env:
env.observation_space = "AutophaseDict"
a = env.reset()

with gym.make("llvm-v0", observation_space="AutophaseDict") as env:
b = env.reset()

assert a == b


if __name__ == "__main__":
main()