Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[datasets] Add a random_benchmark() method. #247

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler_gym/bin/manual_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def do_set_benchmark(self, arg):
Use '-' for a random benchmark.
"""
if arg == "-":
arg = self.env.datasets.benchmark().uri
arg = self.env.datasets.random_benchmark().uri
print(f"set_benchmark {arg}")

try:
Expand Down
24 changes: 24 additions & 0 deletions compiler_gym/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
from typing import Dict, Iterable, Optional, Union

import numpy as np
from deprecated.sphinx import deprecated as mark_deprecated

from compiler_gym.datasets.benchmark import Benchmark
Expand Down Expand Up @@ -358,6 +359,29 @@ def benchmark(self, uri: str) -> Benchmark:
"""
raise NotImplementedError("abstract class")

def random_benchmark(
    self, random_state: Optional[np.random.Generator] = None
) -> Benchmark:
    """Select a benchmark randomly.

    This is the public entry point. The selection itself is delegated to
    :code:`self._random_benchmark()`, which receives the generator chosen
    here.

    :param random_state: A random number generator. If not provided, a
        default :code:`np.random.default_rng()` is used.

    :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
        instance.
    """
    # Compare against None explicitly: None is the "not provided" sentinel,
    # so the fallback should not depend on the truthiness of the argument.
    if random_state is None:
        random_state = np.random.default_rng()
    return self._random_benchmark(random_state)

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
    """Pick one benchmark at random (implementation hook).

    This base version has no implementation and always raises. An
    overriding subclass must choose among the available benchmarks with
    uniform probability, and must draw all of its randomness from
    :code:`random_state`.

    :param random_state: The random number generator to draw from.

    :raises NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError("abstract class")

def __getitem__(self, uri: str) -> Benchmark:
"""Select a benchmark by URI.

Expand Down
40 changes: 39 additions & 1 deletion compiler_gym/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import deque
from typing import Dict, Iterable, Set, TypeVar
from typing import Dict, Iterable, Optional, Set, TypeVar

import numpy as np

from compiler_gym.datasets.benchmark import Benchmark
from compiler_gym.datasets.dataset import Dataset
Expand Down Expand Up @@ -251,6 +253,42 @@ def benchmark(self, uri: str) -> Benchmark:

return dataset.benchmark(uri)

def random_benchmark(
    self, random_state: Optional[np.random.Generator] = None
) -> Benchmark:
    """Select a benchmark randomly.

    First, a dataset is selected uniformly randomly using
    :code:`random_state.choice(list(datasets))`. The
    :meth:`random_benchmark()
    <compiler_gym.datasets.Dataset.random_benchmark>` method of that dataset
    is then called to select a benchmark.

    Note that the distribution of benchmarks selected by this method is not
    weighted by the size of each dataset, since datasets are selected
    uniformly. This means that datasets with a small number of benchmarks
    will be overrepresented compared to datasets with many benchmarks. To
    correct for this bias, use the relative number of benchmarks in each
    dataset as the selection probability. The probabilities given to
    :code:`choice()` must sum to 1, so the sizes are normalized first, and
    an index is drawn rather than a dataset object:

        >>> import math
        >>> rng = np.random.default_rng()
        >>> finite_datasets = [d for d in env.datasets if len(d) != math.inf]
        >>> sizes = np.array([len(d) for d in finite_datasets], dtype=float)
        >>> index = rng.choice(len(finite_datasets), p=sizes / sizes.sum())
        >>> finite_datasets[index].random_benchmark(random_state=rng)

    :param random_state: A random number generator. If not provided, a
        default :code:`np.random.default_rng()` is used.

    :return: A :class:`Benchmark <compiler_gym.datasets.Benchmark>`
        instance.
    """
    # Compare against None explicitly: None is the "not provided" sentinel,
    # so the fallback should not depend on the truthiness of the argument.
    if random_state is None:
        random_state = np.random.default_rng()
    # NOTE(review): if `_visible_datasets` is an unordered collection (e.g. a
    # set), the order of this list, and so the dataset picked for a given
    # seed, may differ between processes. Confirm whether a stable ordering
    # is needed for reproducibility.
    dataset = random_state.choice(list(self._visible_datasets))
    # The same generator is forwarded so that a single seed controls both the
    # dataset choice and the benchmark choice.
    return self[dataset].random_benchmark(random_state=random_state)

@property
def size(self) -> int:
return len(self._visible_datasets)
Expand Down
5 changes: 5 additions & 0 deletions compiler_gym/datasets/files_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import Iterable, List

import numpy as np

from compiler_gym.datasets.dataset import Benchmark, Dataset
from compiler_gym.util.decorators import memoized_property

Expand Down Expand Up @@ -117,3 +119,6 @@ def benchmark(self, uri: str) -> Benchmark:
if not abspath.is_file():
raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
return self.benchmark_class.from_file(uri, abspath)

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
    """Pick one of this dataset's benchmark URIs at random and load it.

    All URIs produced by :code:`self.benchmark_uris()` are collected into a
    list, one is drawn with :code:`random_state.choice()`, and the result of
    :code:`self.benchmark()` for that URI is returned.

    :param random_state: The random number generator to draw from.

    :return: Whatever :code:`self.benchmark()` returns for the chosen URI.
    """
    candidate_uris = list(self.benchmark_uris())
    chosen_uri = random_state.choice(candidate_uris)
    return self.benchmark(chosen_uri)
5 changes: 5 additions & 0 deletions compiler_gym/envs/llvm/datasets/csmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from threading import Lock
from typing import Iterable, List

import numpy as np
from fasteners import InterProcessLock

from compiler_gym.datasets import Benchmark, BenchmarkSource, Dataset
Expand Down Expand Up @@ -227,6 +228,10 @@ def benchmark_uris(self) -> Iterable[str]:
def benchmark(self, uri: str) -> CsmithBenchmark:
    """Look up a benchmark by URI.

    The text after the final ``/`` of ``uri`` is parsed as an integer seed
    and handed to :code:`self.benchmark_from_seed()`.

    :param uri: A benchmark URI whose last path component is an integer.

    :return: The result of :code:`self.benchmark_from_seed()` for that seed.
    """
    *_, seed_component = uri.split("/")
    return self.benchmark_from_seed(int(seed_component))

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
    """Generate a benchmark from a randomly drawn seed.

    :param random_state: The random number generator to draw the seed from.

    :return: The result of :code:`self.benchmark_from_seed()` for the drawn
        seed.
    """
    # Generator.integers() returns a NumPy integer scalar. Convert it to a
    # builtin int so benchmark_from_seed() receives the type it is annotated
    # to accept.
    # NOTE(review): integers(UINT_MAX) draws from [0, UINT_MAX), so UINT_MAX
    # itself is never produced. Confirm whether that is intended.
    seed = int(random_state.integers(UINT_MAX))
    return self.benchmark_from_seed(seed)

def benchmark_from_seed(self, seed: int) -> CsmithBenchmark:
"""Get a benchmark from a uint32 seed.

Expand Down
6 changes: 6 additions & 0 deletions compiler_gym/envs/llvm/datasets/llvm_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import Iterable

import numpy as np

from compiler_gym.datasets import Benchmark, Dataset
from compiler_gym.datasets.benchmark import BenchmarkInitError
from compiler_gym.third_party import llvm
Expand Down Expand Up @@ -56,6 +58,10 @@ def benchmark_uris(self) -> Iterable[str]:
def benchmark(self, uri: str) -> Benchmark:
    """Look up a benchmark by URI.

    The text after the final ``/`` of ``uri`` is parsed as an integer seed
    and handed to :code:`self.benchmark_from_seed()`.

    :param uri: A benchmark URI whose last path component is an integer.

    :return: The result of :code:`self.benchmark_from_seed()` for that seed.
    """
    *_, seed_component = uri.split("/")
    return self.benchmark_from_seed(int(seed_component))

def _random_benchmark(self, random_state: np.random.Generator) -> Benchmark:
    """Generate a benchmark from a randomly drawn seed.

    :param random_state: The random number generator to draw the seed from.

    :return: The result of :code:`self.benchmark_from_seed()` for the drawn
        seed.
    """
    # Generator.integers() returns a NumPy integer scalar. Convert it to a
    # builtin int so benchmark_from_seed() receives the type it is annotated
    # to accept.
    # NOTE(review): integers(UINT_MAX) draws from [0, UINT_MAX), so UINT_MAX
    # itself is never produced. Confirm whether that is intended.
    seed = int(random_state.integers(UINT_MAX))
    return self.benchmark_from_seed(seed)

def benchmark_from_seed(self, seed: int) -> Benchmark:
"""Get a benchmark from a uint32 seed.

Expand Down
35 changes: 28 additions & 7 deletions tests/datasets/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/datasets."""
import numpy as np
import pytest

from compiler_gym.datasets.datasets import Datasets, round_robin_iterables
Expand Down Expand Up @@ -33,13 +34,14 @@ def benchmark_uris(self):
def benchmarks(self):
    """Yield every entry of ``self.benchmark_values`` in order."""
    for stored in self.benchmark_values:
        yield stored

def benchmark(self, uri=None):
    """Return a stored benchmark, by URI when one is given.

    With a falsy ``uri`` the first entry of ``self.benchmark_values`` is
    returned. Otherwise the first entry whose ``uri`` attribute equals
    ``uri`` is returned, and ``KeyError(uri)`` is raised when none matches.
    """
    if not uri:
        return self.benchmark_values[0]
    # A private sentinel distinguishes "no match" from any stored value.
    not_found = object()
    hit = next(
        (candidate for candidate in self.benchmark_values if candidate.uri == uri),
        not_found,
    )
    if hit is not_found:
        raise KeyError(uri)
    return hit
def benchmark(self, uri):
    """Return the first stored benchmark whose ``uri`` attribute equals ``uri``.

    Raises ``KeyError(uri)`` when no entry of ``self.benchmark_values``
    matches.
    """
    # A private sentinel distinguishes "no match" from any stored value.
    not_found = object()
    hit = next(
        (candidate for candidate in self.benchmark_values if candidate.uri == uri),
        not_found,
    )
    if hit is not_found:
        raise KeyError(uri)
    return hit

def random_benchmark(self, random_state=None):
    """Return a randomly chosen entry of ``self.benchmark_values``.

    :param random_state: A random number generator providing ``choice()``.
        If not provided, a default :code:`np.random.default_rng()` is used.

    :return: One element of ``self.benchmark_values``.
    """
    # The signature advertises None as a valid default, so honor it instead
    # of failing with AttributeError on `None.choice`.
    if random_state is None:
        random_state = np.random.default_rng()
    return random_state.choice(self.benchmark_values)

def __repr__(self):
    """Use ``self.name``, converted to ``str``, as the representation."""
    label = self.name
    return str(label)
Expand Down Expand Up @@ -243,5 +245,24 @@ def test_benchmarks_iter_deprecated():
]


def test_random_benchmark(mocker):
    """Datasets.random_benchmark() delegates to the dataset on every call.

    With a single dataset holding a single benchmark, each draw must reach
    that dataset's random_benchmark() and yield that one benchmark's URI.
    """
    num_benchmarks = 5

    dataset = MockDataset("benchmark://foo-v0")
    dataset.benchmark_values.append(MockBenchmark(uri="benchmark://foo-v0/abc"))
    datasets = Datasets([dataset])

    # Spy on the per-dataset method so the delegation can be counted.
    mocker.spy(dataset, "random_benchmark")

    rng = np.random.default_rng(0)
    seen_uris = set()
    for _ in range(num_benchmarks):
        seen_uris.add(datasets.random_benchmark(rng).uri)

    assert dataset.random_benchmark.call_count == num_benchmarks
    assert len(seen_uris) == 1
    assert next(iter(seen_uris)) == "benchmark://foo-v0/abc"


if __name__ == "__main__":
main()
13 changes: 13 additions & 0 deletions tests/datasets/files_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
from pathlib import Path

import numpy as np
import pytest

from compiler_gym.datasets import FilesDataset
Expand Down Expand Up @@ -111,5 +112,17 @@ def test_populated_dataset_with_file_extension_filter(populated_dataset: FilesDa
assert populated_dataset.size == 2


def test_populated_dataset_random_benchmark(populated_dataset: FilesDataset):
    """Three seeded random draws from the populated dataset are all distinct."""
    # NOTE(review): expecting every draw to be distinct relies on the seeded
    # generator never repeating a URI within `num_benchmarks` draws for this
    # fixture. Confirm the fixture's contents and ordering make that stable.
    rng = np.random.default_rng(0)
    num_benchmarks = 3
    drawn_uris = set()
    for _ in range(num_benchmarks):
        drawn_uris.add(populated_dataset.random_benchmark(rng).uri)
    assert len(drawn_uris) == num_benchmarks


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions tests/llvm/datasets/csmith_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

import gym
import numpy as np
import pytest

import compiler_gym.envs.llvm # noqa register environments
Expand Down Expand Up @@ -47,5 +48,16 @@ def test_csmith_random_select(
assert (tmpwd / "source.c").is_file()


@skip_on_ci
def test_random_benchmark(csmith_dataset: CsmithDataset):
    """Five seeded random draws from the Csmith dataset have distinct URIs."""
    rng = np.random.default_rng(0)
    num_benchmarks = 5
    drawn_uris = set()
    for _ in range(num_benchmarks):
        drawn_uris.add(csmith_dataset.random_benchmark(rng).uri)
    assert len(drawn_uris) == num_benchmarks


if __name__ == "__main__":
main()
14 changes: 14 additions & 0 deletions tests/llvm/datasets/llvm_stress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from itertools import islice

import gym
import numpy as np
import pytest

import compiler_gym.envs.llvm # noqa register environments
Expand Down Expand Up @@ -59,5 +60,18 @@ def test_llvm_stress_random_select(
assert instcount["TotalInstsCount"] > 0


@skip_on_ci
def test_random_benchmark(llvm_stress_dataset: LlvmStressDataset):
    """Five seeded random draws from the llvm-stress dataset have distinct URIs."""
    rng = np.random.default_rng(0)
    num_benchmarks = 5
    drawn_uris = set()
    for _ in range(num_benchmarks):
        drawn_uris.add(llvm_stress_dataset.random_benchmark(rng).uri)
    assert len(drawn_uris) == num_benchmarks


if __name__ == "__main__":
main()