Skip to content

Commit

Permalink
Make derived observation spaces useable at construction time.
Browse files Browse the repository at this point in the history
Fixes #461.
  • Loading branch information
ChrisCummins committed Oct 11, 2021
1 parent 483c694 commit 795588b
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 101 deletions.
10 changes: 10 additions & 0 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
observation_space: Optional[Union[str, ObservationSpaceSpec]] = None,
reward_space: Optional[Union[str, Reward]] = None,
action_space: Optional[str] = None,
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
logger: Optional[logging.Logger] = None,
Expand Down Expand Up @@ -205,6 +206,10 @@ def __init__(
:param action_space: The name of the action space to use. If not
specified, the default action space for this compiler is used.
:param derived_observation_spaces: An optional list of arguments to be
passed to :meth:`env.observation.add_derived_space()
<compiler_gym.views.observation.Observation.add_derived_space>`.
:param connection_settings: The settings used to establish a connection
with the remote service.
Expand Down Expand Up @@ -299,6 +304,11 @@ def __init__(
)
self.reward = self._reward_view_type(rewards, self.observation)

# Register any derived observation spaces now so that the observation
# space can be set below.
for derived_observation_space in derived_observation_spaces or []:
self.observation.add_derived_space_internal(**derived_observation_space)

# Lazily evaluated version strings.
self._versions: Optional[GetVersionReply] = None

Expand Down
204 changes: 103 additions & 101 deletions compiler_gym/envs/llvm/llvm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
# First perform a one-time download of LLVM binaries that are needed by
# the LLVM service and are not included by the pip-installed package.
download_llvm_files()
self.inst2vec = _INST2VEC_ENCODER
super().__init__(
*args,
**kwargs,
Expand Down Expand Up @@ -150,15 +151,115 @@ def __init__(
platform_dependent=True,
),
],
derived_observation_spaces=[
{
"id": "Inst2vecPreprocessedText",
"base_id": "Ir",
"space": Sequence(
name="Inst2vecPreprocessedText", size_range=(0, None), dtype=str
),
"translate": self.inst2vec.preprocess,
"default_value": "",
},
{
"id": "Inst2vecEmbeddingIndices",
"base_id": "Ir",
"space": Sequence(
name="Inst2vecEmbeddingIndices",
size_range=(0, None),
dtype=np.int32,
),
"translate": lambda base_observation: self.inst2vec.encode(
self.inst2vec.preprocess(base_observation)
),
"default_value": np.array([self.inst2vec.vocab["!UNK"]]),
},
{
"id": "Inst2vec",
"base_id": "Ir",
"space": Sequence(
name="Inst2vec", size_range=(0, None), dtype=np.ndarray
),
"translate": lambda base_observation: self.inst2vec.embed(
self.inst2vec.encode(self.inst2vec.preprocess(base_observation))
),
"default_value": np.vstack(
[self.inst2vec.embeddings[self.inst2vec.vocab["!UNK"]]]
),
},
{
"id": "InstCountDict",
"base_id": "InstCount",
"space": DictSpace(
{
f"{name}Count": Scalar(
name=f"{name}Count", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES
},
name="InstCountDict",
),
"translate": lambda base_observation: {
f"{name}Count": val
for name, val in zip(INST_COUNT_FEATURE_NAMES, base_observation)
},
},
{
"id": "InstCountNorm",
"base_id": "InstCount",
"space": Box(
name="InstCountNorm",
low=0,
high=1,
shape=(len(INST_COUNT_FEATURE_NAMES) - 1,),
dtype=np.float32,
),
"translate": lambda base_observation: (
base_observation[1:] / max(base_observation[0], 1)
).astype(np.float32),
},
{
"id": "InstCountNormDict",
"base_id": "InstCountNorm",
"space": DictSpace(
{
f"{name}Density": Scalar(
name=f"{name}Density", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES[1:]
},
name="InstCountNormDict",
),
"translate": lambda base_observation: {
f"{name}Density": val
for name, val in zip(
INST_COUNT_FEATURE_NAMES[1:], base_observation
)
},
},
{
"id": "AutophaseDict",
"base_id": "Autophase",
"space": DictSpace(
{
name: Scalar(name=name, min=0, max=None, dtype=int)
for name in AUTOPHASE_FEATURE_NAMES
},
name="AutophaseDict",
),
"translate": lambda base_observation: {
name: val
for name, val in zip(AUTOPHASE_FEATURE_NAMES, base_observation)
},
},
],
)

# Mutable runtime configuration options that must be set on every call
# to reset.
self._runtimes_per_observation_count: Optional[int] = None
self._runtimes_warmup_per_observation_count: Optional[int] = None

self.inst2vec = _INST2VEC_ENCODER

cpu_info_spaces = [
Sequence(name="name", size_range=(0, None), dtype=str),
Scalar(name="cores_count", min=None, max=None, dtype=int),
Expand All @@ -178,105 +279,6 @@ def __init__(
name="CpuInfo",
)

self.observation.add_derived_space(
id="Inst2vecPreprocessedText",
base_id="Ir",
space=Sequence(
name="Inst2vecPreprocessedText", size_range=(0, None), dtype=str
),
translate=self.inst2vec.preprocess,
default_value="",
)
self.observation.add_derived_space(
id="Inst2vecEmbeddingIndices",
base_id="Ir",
space=Sequence(
name="Inst2vecEmbeddingIndices", size_range=(0, None), dtype=np.int32
),
translate=lambda base_observation: self.inst2vec.encode(
self.inst2vec.preprocess(base_observation)
),
default_value=np.array([self.inst2vec.vocab["!UNK"]]),
)
self.observation.add_derived_space(
id="Inst2vec",
base_id="Ir",
space=Sequence(name="Inst2vec", size_range=(0, None), dtype=np.ndarray),
translate=lambda base_observation: self.inst2vec.embed(
self.inst2vec.encode(self.inst2vec.preprocess(base_observation))
),
default_value=np.vstack(
[self.inst2vec.embeddings[self.inst2vec.vocab["!UNK"]]]
),
)

self.observation.add_derived_space(
id="InstCountDict",
base_id="InstCount",
space=DictSpace(
{
f"{name}Count": Scalar(
name=f"{name}Count", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES
},
name="InstCountDict",
),
translate=lambda base_observation: {
f"{name}Count": val
for name, val in zip(INST_COUNT_FEATURE_NAMES, base_observation)
},
)

self.observation.add_derived_space(
id="InstCountNorm",
base_id="InstCount",
space=Box(
name="InstCountNorm",
low=0,
high=1,
shape=(len(INST_COUNT_FEATURE_NAMES) - 1,),
dtype=np.float32,
),
translate=lambda base_observation: (
base_observation[1:] / max(base_observation[0], 1)
).astype(np.float32),
)

self.observation.add_derived_space(
id="InstCountNormDict",
base_id="InstCountNorm",
space=DictSpace(
{
f"{name}Density": Scalar(
name=f"{name}Density", min=0, max=None, dtype=int
)
for name in INST_COUNT_FEATURE_NAMES[1:]
},
name="InstCountNormDict",
),
translate=lambda base_observation: {
f"{name}Density": val
for name, val in zip(INST_COUNT_FEATURE_NAMES[1:], base_observation)
},
)

self.observation.add_derived_space(
id="AutophaseDict",
base_id="Autophase",
space=DictSpace(
{
name: Scalar(name=name, min=0, max=None, dtype=int)
for name in AUTOPHASE_FEATURE_NAMES
},
name="AutophaseDict",
),
translate=lambda base_observation: {
name: val
for name, val in zip(AUTOPHASE_FEATURE_NAMES, base_observation)
},
)

def reset(self, *args, **kwargs):
try:
observation = super().reset(*args, **kwargs)
Expand Down
13 changes: 13 additions & 0 deletions compiler_gym/views/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,18 @@ def add_derived_space(
base_space = self.spaces[base_id]
self._add_space(base_space.make_derived_space(id=id, **kwargs))

# NOTE(github.com/facebookresearch/CompilerGym/issues/461): This method will
# be renamed to add_derived_space() once the current method with that name
# is removed.
def add_derived_space_internal(
    self,
    id: str,
    base_id: str,
    **kwargs,
) -> None:
    """Register a new observation space derived from an existing one.

    :param id: The name given to the newly derived space.
    :param base_id: The key in ``self.spaces`` of the space to derive from.
    :param kwargs: Extra arguments forwarded unchanged to the base space's
        ``make_derived_space()``.
    """
    # Build the derived space from its base first, then register it.
    derived_space = self.spaces[base_id].make_derived_space(id=id, **kwargs)
    self._add_space(derived_space)

def __repr__(self):
    """Return ``ObservationView[...]`` listing the space names in sorted order."""
    space_names = ", ".join(sorted(self.spaces.keys()))
    return f"ObservationView[{space_names}]"
9 changes: 9 additions & 0 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from typing import Any, Dict, List

import gym
import networkx as nx
import numpy as np
import pytest
Expand Down Expand Up @@ -1399,5 +1400,13 @@ def test_add_derived_space(env: LlvmEnv):
assert value == [15]


def test_derived_space_constructor():
    """Check that an environment can be built with a derived observation space.

    Constructing the environment with ``observation_space="AutophaseDict"``
    must succeed; the test passes if construction and the context manager
    enter/exit raise nothing.
    """
    env = gym.make("llvm-v0", observation_space="AutophaseDict")
    with env:
        pass


# Run the test entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 comments on commit 795588b

Please sign in to comment.