Skip to content

Commit

Permalink
[datasets] Add a random_benchmark() method.
Browse files Browse the repository at this point in the history
The v0.1.8 release removed the random benchmark selection from
CompilerGym environments when no benchmark was specified. Users who
wanted random benchmark selection were required to roll
their own implementation. Randomly sampling from
env.dataset.benchmark_uris() is not always easy as the generator may
be infinite. For some datasets, e.g. Csmith, it is trivial to select
random benchmarks by generating random numbers within the range of
numeric seed values, but this is not obvious and the user shouldn't
have to figure this out for the simple case of uniform random
selection.

This adds a `random_benchmark()` method to the `Dataset` class which
allows uniform random benchmark selection, and a `random_benchmark()`
method to the `Datasets` class for sampling across datasets.

Issue #240.
  • Loading branch information
ChrisCummins committed May 4, 2021
1 parent 6a24c6d commit 40bf310
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 2 deletions.
2 changes: 1 addition & 1 deletion compiler_gym/bin/manual_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def do_set_benchmark(self, arg):
Use '-' for a random benchmark.
"""
if arg == "-":
arg = self.env.datasets.benchmark().uri
arg = self.env.datasets.random_benchmark().uri
print(f"set_benchmark {arg}")

try:
Expand Down
24 changes: 24 additions & 0 deletions compiler_gym/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
from typing import Dict, Iterable, Optional, Union

import numpy as np
from deprecated.sphinx import deprecated as mark_deprecated

from compiler_gym.datasets.benchmark import Benchmark
Expand Down Expand Up @@ -358,6 +359,29 @@ def benchmark(self, uri: str) -> Benchmark:
"""
raise NotImplementedError("abstract class")

def random_benchmark(
    self, random_state: Optional[np.random.Generator] = None
) -> Benchmark:
    """Return a benchmark drawn at random from this dataset.

    The selection itself is delegated to :meth:`_random_benchmark`; this
    method only supplies a generator when the caller did not provide one.

    :param random_state: The random number generator to draw from. When
        omitted, a new :code:`np.random.default_rng()` is created for this
        call.

    :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
        instance.
    """
    if not random_state:
        random_state = np.random.default_rng()
    chosen = self._random_benchmark(random_state)
    return chosen

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
"""Private implementation of the random benchmark getter.
Subclasses must implement this method so that it selects a benchmark
from the available benchmarks with uniform probability, using only
:code:`random_state` as a source of randomness.
"""
raise NotImplementedError("abstract class")

def __getitem__(self, uri: str) -> Benchmark:
"""Select a benchmark by URI.
Expand Down
40 changes: 39 additions & 1 deletion compiler_gym/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
from typing import Dict, Iterable, Set, TypeVar
from typing import Dict, Iterable, Optional, Set, TypeVar

import numpy as np

from compiler_gym.datasets.benchmark import Benchmark
from compiler_gym.datasets.dataset import Dataset
Expand Down Expand Up @@ -251,6 +253,42 @@ def benchmark(self, uri: str) -> Benchmark:

return dataset.benchmark(uri)

def random_benchmark(
    self, random_state: Optional[np.random.Generator] = None
) -> Benchmark:
    """Select a benchmark randomly.

    First, a dataset is selected uniformly randomly using
    :code:`random_state.choice(list(datasets))`. The
    :meth:`random_benchmark()
    <compiler_gym.datasets.Dataset.random_benchmark>` method of that dataset
    is then called to select a benchmark.

    Note that datasets are selected uniformly, so the choice of dataset is
    not weighted by the number of benchmarks it contains. This means that
    datasets with a small number of benchmarks will be overrepresented
    compared to datasets with many benchmarks. To correct for this bias, use
    the number of benchmarks in each dataset as a weight for the random
    selection. The weights must be normalized, because the :code:`p`
    argument of :code:`Generator.choice` expects probabilities that sum to
    one:

        >>> rng = np.random.default_rng()
        >>> finite_datasets = [d for d in env.datasets if len(d) != math.inf]
        >>> weights = np.array([len(d) for d in finite_datasets], dtype=float)
        >>> dataset = rng.choice(finite_datasets, p=weights / weights.sum())
        >>> dataset.random_benchmark(random_state=rng)

    :param random_state: A random number generator. If not provided, a
        default :code:`np.random.default_rng()` is used.

    :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
        instance.
    """
    random_state = random_state or np.random.default_rng()
    # Candidates are the entries of self._visible_datasets; the chosen entry
    # is then used to index this object to obtain the dataset itself.
    dataset = random_state.choice(list(self._visible_datasets))
    # Reuse the same generator so that a seeded caller gets a reproducible
    # dataset *and* benchmark selection.
    return self[dataset].random_benchmark(random_state=random_state)

@property
def size(self) -> int:
    """The number of entries in :code:`self._visible_datasets`."""
    visible = self._visible_datasets
    return len(visible)
Expand Down
5 changes: 5 additions & 0 deletions compiler_gym/datasets/files_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import Iterable, List

import numpy as np

from compiler_gym.datasets.dataset import Benchmark, Dataset
from compiler_gym.util.decorators import memoized_property

Expand Down Expand Up @@ -117,3 +119,6 @@ def benchmark(self, uri: str) -> Benchmark:
if not abspath.is_file():
raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
return self.benchmark_class.from_file(uri, abspath)

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
return self.benchmark(random_state.choice(list(self.benchmark_uris())))
5 changes: 5 additions & 0 deletions compiler_gym/envs/llvm/datasets/csmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from threading import Lock
from typing import Iterable, List

import numpy as np
from fasteners import InterProcessLock

from compiler_gym.datasets import Benchmark, BenchmarkSource, Dataset
Expand Down Expand Up @@ -227,6 +228,10 @@ def benchmark_uris(self) -> Iterable[str]:
def benchmark(self, uri: str) -> CsmithBenchmark:
    """Return the benchmark identified by the seed at the end of a URI.

    :param uri: A benchmark URI whose final :code:`/`-separated component is
        an integer seed.

    :return: The result of :meth:`benchmark_from_seed` for that seed.

    :raises ValueError: If the final URI component cannot be parsed as an
        integer.
    """
    seed_text = uri.split("/")[-1]
    seed = int(seed_text)
    return self.benchmark_from_seed(seed)

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
    """Draw a random seed and return the benchmark generated from it.

    The seed comes from :code:`random_state.integers(UINT_MAX)`, i.e. an
    integer in the half-open range :code:`[0, UINT_MAX)`.

    :param random_state: The random number generator to draw the seed from.

    :return: The result of :meth:`benchmark_from_seed` for the drawn seed.
    """
    return self.benchmark_from_seed(random_state.integers(UINT_MAX))

def benchmark_from_seed(self, seed: int) -> CsmithBenchmark:
"""Get a benchmark from a uint32 seed.
Expand Down
6 changes: 6 additions & 0 deletions compiler_gym/envs/llvm/datasets/llvm_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import Iterable

import numpy as np

from compiler_gym.datasets import Benchmark, Dataset
from compiler_gym.datasets.benchmark import BenchmarkInitError
from compiler_gym.third_party import llvm
Expand Down Expand Up @@ -56,6 +58,10 @@ def benchmark_uris(self) -> Iterable[str]:
def benchmark(self, uri: str) -> Benchmark:
    """Return the benchmark identified by the seed at the end of a URI.

    :param uri: A benchmark URI whose final :code:`/`-separated component is
        an integer seed.

    :return: The result of :meth:`benchmark_from_seed` for that seed.

    :raises ValueError: If the final URI component cannot be parsed as an
        integer.
    """
    seed_text = uri.split("/")[-1]
    seed = int(seed_text)
    return self.benchmark_from_seed(seed)

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
    """Draw a random seed and return the benchmark generated from it.

    The seed comes from :code:`random_state.integers(UINT_MAX)`, i.e. an
    integer in the half-open range :code:`[0, UINT_MAX)`.

    :param random_state: The random number generator to draw the seed from.

    :return: The result of :meth:`benchmark_from_seed` for the drawn seed.
    """
    return self.benchmark_from_seed(random_state.integers(UINT_MAX))

def benchmark_from_seed(self, seed: int) -> Benchmark:
"""Get a benchmark from a uint32 seed.
Expand Down
22 changes: 22 additions & 0 deletions tests/datasets/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/datasets."""
import numpy as np
import pytest

from compiler_gym.datasets.datasets import Datasets, round_robin_iterables
Expand Down Expand Up @@ -243,5 +244,26 @@ def test_benchmarks_iter_deprecated():
]


def test_random_benchmark():
    """Randomly sampled benchmarks come from the registered datasets."""
    da = MockDataset("benchmark://foo-v0")
    db = MockDataset("benchmark://bar-v0")
    ba = MockBenchmark(uri="benchmark://foo-v0/abc")
    bb = MockBenchmark(uri="benchmark://foo-v0/123")
    bc = MockBenchmark(uri="benchmark://bar-v0/abc")
    bd = MockBenchmark(uri="benchmark://bar-v0/123")
    da.benchmark_values.append(ba)
    da.benchmark_values.append(bb)
    db.benchmark_values.append(bc)
    db.benchmark_values.append(bd)
    datasets = Datasets([da, db])

    # Only four distinct URIs are registered, so the number of unique samples
    # can never equal the number of draws once more than four are taken.
    # Instead, check that every sample is a known benchmark and that sampling
    # is not stuck on a single value.
    known_uris = {b.uri for b in (ba, bb, bc, bd)}
    num_benchmarks = 20
    rng = np.random.default_rng(0)
    random_benchmarks = {
        b.uri for b in (datasets.random_benchmark(rng) for _ in range(num_benchmarks))
    }
    assert random_benchmarks <= known_uris
    assert len(random_benchmarks) > 1


if __name__ == "__main__":
main()
13 changes: 13 additions & 0 deletions tests/datasets/files_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
from pathlib import Path

import numpy as np
import pytest

from compiler_gym.datasets import FilesDataset
Expand Down Expand Up @@ -111,5 +112,17 @@ def test_populated_dataset_with_file_extension_filter(populated_dataset: FilesDa
assert populated_dataset.size == 2


def test_populated_dataset_random_benchmark(populated_dataset: FilesDataset):
    """Three seeded random draws from the dataset yield three distinct URIs."""
    sample_count = 3
    generator = np.random.default_rng(0)
    seen_uris = set()
    for _ in range(sample_count):
        sampled = populated_dataset.random_benchmark(generator)
        seen_uris.add(sampled.uri)
    assert len(seen_uris) == sample_count


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions tests/llvm/datasets/csmith_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

import gym
import numpy as np
import pytest

import compiler_gym.envs.llvm # noqa register environments
Expand Down Expand Up @@ -47,5 +48,16 @@ def test_csmith_random_select(
assert (tmpwd / "source.c").is_file()


@skip_on_ci
def test_random_benchmark(csmith_dataset: CsmithDataset):
    """Five seeded random draws from the dataset yield five distinct URIs."""
    sample_count = 5
    generator = np.random.default_rng(0)
    seen_uris = set()
    for _ in range(sample_count):
        sampled = csmith_dataset.random_benchmark(generator)
        seen_uris.add(sampled.uri)
    assert len(seen_uris) == sample_count


if __name__ == "__main__":
main()
14 changes: 14 additions & 0 deletions tests/llvm/datasets/llvm_stress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from itertools import islice

import gym
import numpy as np
import pytest

import compiler_gym.envs.llvm # noqa register environments
Expand Down Expand Up @@ -59,5 +60,18 @@ def test_llvm_stress_random_select(
assert instcount["TotalInstsCount"] > 0


@skip_on_ci
def test_random_benchmark(llvm_stress_dataset: LlvmStressDataset):
    """Five seeded random draws from the dataset yield five distinct URIs."""
    sample_count = 5
    generator = np.random.default_rng(0)
    seen_uris = set()
    for _ in range(sample_count):
        sampled = llvm_stress_dataset.random_benchmark(generator)
        seen_uris.add(sampled.uri)
    assert len(seen_uris) == sample_count


if __name__ == "__main__":
main()

0 comments on commit 40bf310

Please sign in to comment.