-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[datasets] Add Dataset helper subclasses.
Issue #45.
- Loading branch information
1 parent
75730f1
commit 7b86fde
Showing
6 changed files
with
527 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import os | ||
from pathlib import Path | ||
from typing import Iterable, List, Optional | ||
|
||
from compiler_gym.datasets.dataset import Benchmark, Dataset | ||
from compiler_gym.util.decorators import memoized_property | ||
|
||
|
||
class FilesDataset(Dataset):
    """A dataset comprising a directory tree of files.

    A FilesDataset is a root directory that contains (possibly nested) files,
    where each file represents a benchmark. Files can be filtered on their
    expected filename suffix.

    Every file that matches the expected suffix is a benchmark. The URI of a
    benchmark is the relative path of the file within the dataset, stripped of
    any matching file suffix. For example, given a dataset rooted at the
    directory tree :code:`/tmp/dataset` and with file suffix :code:`.txt`:

    .. code-block::

        /tmp/dataset/a.txt
        /tmp/dataset/LICENSE
        /tmp/dataset/subdir/subdir/b.txt
        /tmp/dataset/subdir/subdir/c.txt

    a FilesDataset representing this directory tree will contain the following
    URIs:

        >>> list(dataset.benchmark_uris())
        [
            "benchmark://ds-v0/a",
            "benchmark://ds-v0/subdir/subdir/b",
            "benchmark://ds-v0/subdir/subdir/c",
        ]
    """

    def __init__(
        self,
        dataset_root: Path,
        benchmark_file_suffix: str = "",
        memoize_uris: bool = True,
        **dataset_args,
    ):
        """Constructor.

        :param dataset_root: The root directory to look for benchmark files.
        :param benchmark_file_suffix: A file extension that must be matched for
            a file to be used as a benchmark.
        :param memoize_uris: Whether to memoize the list of URIs contained in
            the dataset. Memoizing the URIs is a tradeoff between *O(n)*
            computation complexity of random access vs *O(n)* space complexity
            of memoizing the URI list.
        :param dataset_args: See :meth:`Dataset.__init__()
            <compiler_gym.datasets.Dataset.__init__>`.
        """
        super().__init__(**dataset_args)
        self.dataset_root = dataset_root
        self.benchmark_file_suffix = benchmark_file_suffix
        self.memoize_uris = memoize_uris
        # Lazily-populated URI list, see benchmark_uris().
        self._memoized_uris = None

    @memoized_property
    def n(self) -> int:  # pylint: disable=invalid-overriden-method
        """Return the number of benchmarks (files matching the suffix)."""
        self.install()
        return sum(
            sum(1 for f in files if f.endswith(self.benchmark_file_suffix))
            for (_, _, files) in os.walk(self.dataset_root)
        )

    def _strip_suffix(self, filename: str) -> str:
        """Strip the benchmark file suffix from a filename, if one is set.

        NOTE: filename[:-0] would yield the empty string, so only slice when
        the suffix is non-empty.
        """
        if self.benchmark_file_suffix:
            return filename[: -len(self.benchmark_file_suffix)]
        return filename

    @property
    def _benchmark_uris_iter(self) -> Iterable[str]:
        """Return an iterator over benchmark URIs that is consistent across runs."""
        self.install()
        for root, dirs, files in os.walk(self.dataset_root):
            # Sort the subdirectories in-place so that os.walk() visits them
            # in a deterministic order.
            dirs.sort()
            reldir = root[len(str(self.dataset_root)) + 1 :]
            for filename in sorted(files):
                # If we have an expected file suffix then filter on it and
                # strip it from the filename.
                if self.benchmark_file_suffix and not filename.endswith(
                    self.benchmark_file_suffix
                ):
                    continue
                # Use os.path.join() rather than simple '/' concatenation as
                # reldir may be empty.
                yield os.path.join(self.name, reldir, self._strip_suffix(filename))

    @property
    def _benchmark_uris(self) -> List[str]:
        """Eagerly materialize the full list of benchmark URIs."""
        return list(self._benchmark_uris_iter)

    def benchmark_uris(self) -> Iterable[str]:
        """Enumerate benchmark URIs, memoizing the list if enabled."""
        if self._memoized_uris:
            yield from self._memoized_uris
        elif self.memoize_uris:
            self._memoized_uris = self._benchmark_uris
            yield from self._memoized_uris
        else:
            yield from self._benchmark_uris_iter

    def benchmark(self, uri: Optional[str] = None) -> Benchmark:
        """Return the benchmark for the given URI, or a random one if omitted.

        :raises LookupError: If the URI does not resolve to a file on disk.
        :raises ValueError: If no URI is given and the dataset is empty.
        """
        self.install()
        if uri is None or len(uri) <= len(self.name) + 1:
            # No specific benchmark was requested; select one at random.
            if not self.n:
                raise ValueError("No benchmarks")
            return self.get_benchmark_by_index(self.random.integers(self.n))

        # Map the URI back to a path relative to the dataset root.
        relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}"
        abspath = self.dataset_root / relpath
        if not abspath.is_file():
            raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
        return self.benchmark_class.from_file(uri, abspath)

    def get_benchmark_by_index(self, n: int) -> Benchmark:
        """Look up a benchmark using a numeric index into the list of URIs."""
        # If we have memoized the benchmark IDs then just index into the list.
        # Otherwise we will scan through the directory hierarchy.
        if self.memoize_uris:
            if not self._memoized_uris:
                self._memoized_uris = self._benchmark_uris
            return self.benchmark(self._memoized_uris[n])

        i = 0
        for root, dirs, files in os.walk(self.dataset_root):
            # Sort in-place unconditionally so that the os.walk() traversal
            # order is deterministic regardless of which branch is taken
            # below.
            dirs.sort()
            reldir = root[len(str(self.dataset_root)) + 1 :]

            # Filter only the files that match the target file suffix.
            valid_files = [f for f in files if f.endswith(self.benchmark_file_suffix)]

            if i + len(valid_files) <= n:
                # There aren't enough files in this directory to bring us up to
                # the target file index, so skip the whole lot.
                i += len(valid_files)
            else:
                # Iterate through the files in the directory in order until we
                # reach the target index.
                for filename in sorted(valid_files):
                    if i == n:
                        # Use os.path.join() rather than simple '/'
                        # concatenation as reldir may be empty.
                        uri = os.path.join(
                            self.name, reldir, self._strip_suffix(filename)
                        )
                        return self.benchmark_class.from_file(
                            uri, f"{root}/{filename}"
                        )
                    i += 1

        raise FileNotFoundError(f"Could not find benchmark with index {n} / {self.n}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import bz2 | ||
import gzip | ||
import io | ||
import os | ||
import shutil | ||
import tarfile | ||
from threading import Lock | ||
from typing import Iterable, List, Optional | ||
|
||
import fasteners | ||
|
||
from compiler_gym.datasets.files_dataset import FilesDataset | ||
from compiler_gym.util.decorators import memoized_property | ||
from compiler_gym.util.download import download | ||
|
||
|
||
class TarDataset(FilesDataset):
    """A dataset comprising a files tree stored in a tar archive.

    This extends the :class:`FilesDataset` class by adding support for
    compressed archives of files. The archive is downloaded and unpacked
    on-demand.
    """

    def __init__(
        self,
        tar_urls: List[str],
        tar_sha256: Optional[str] = None,
        tar_compression: str = "bz2",
        strip_prefix: str = "",
        **dataset_args,
    ):
        """Constructor.

        :param tar_urls: A list of redundant URLS to download the tar archive from.
        :param tar_sha256: The SHA256 checksum of the downloaded tar archive.
        :param tar_compression: The tar archive compression type. One of
            {"bz2", "gz"}.
        :param strip_prefix: An optional path prefix to strip. Only files that
            match this path prefix will be used as benchmarks.
        :param dataset_args: See :meth:`FilesDataset.__init__()
            <compiler_gym.datasets.FilesDataset.__init__>`.
        """
        super().__init__(
            dataset_root=None,  # Set below once site_data_path is resolved.
            **dataset_args,
        )
        self.dataset_root = self.site_data_path / "contents" / strip_prefix

        self.tar_urls = tar_urls
        self.tar_sha256 = tar_sha256
        self.tar_compression = tar_compression
        self.strip_prefix = strip_prefix

        # Installation state: a marker file on disk records a completed
        # extraction; the boolean caches that check in memory.
        self._installed = False
        self._tar_extracted_marker = self.site_data_path / ".extracted"
        # Thread-level and process-level locks guarding extraction.
        self._tar_lock = Lock()
        self._tar_lockfile = self.site_data_path / "LOCK"

    @property
    def installed(self) -> bool:
        """Whether the archive has been downloaded and unpacked."""
        # Cache a positive answer so repeated queries skip the disk check.
        if self._installed:
            return True
        self._installed = self._tar_extracted_marker.is_file()
        return self._installed

    def install(self) -> None:
        """Download and unpack the archive, if not already done."""
        if self.installed:
            return

        # Hold both the thread-level and process-level locks so concurrent
        # callers cannot race on the extraction.
        with self._tar_lock, fasteners.InterProcessLock(self._tar_lockfile):
            # Another thread or process may have completed the installation
            # while we were waiting on the locks, so check again.
            if self.installed:
                return

            self.logger.info("Downloading %s dataset", self.name)
            archive = io.BytesIO(download(self.tar_urls, self.tar_sha256))
            self.logger.info("Unpacking %s dataset", self.name)
            contents_dir = self.site_data_path / "contents"
            with tarfile.open(
                fileobj=archive, mode=f"r:{self.tar_compression}"
            ) as tar:
                # Discard any partially-completed prior extraction.
                shutil.rmtree(contents_dir, ignore_errors=True)
                tar.extractall(str(contents_dir))

            # Record the completed extraction so future checks are cheap.
            self._tar_extracted_marker.touch()

            if self.strip_prefix and not self.dataset_root.is_dir():
                raise FileNotFoundError(
                    f"Directory prefix '{self.strip_prefix}' not found in dataset '{self.name}'"
                )
|
||
|
||
class TarDatasetWithManifest(TarDataset):
    """A tarball-based dataset that has a manifest file which lists the URIs.

    The idea is to allow the list of benchmark URIs to be enumerated in a more
    lightweight manner than downloading and unpacking the entire dataset. It
    does this by downloading a "manifest", which is a list of benchmark names,
    and only downloads the actual tarball containing the benchmarks when it is
    needed.

    The manifest is assumed to be correct and is not validated.
    """

    def __init__(
        self,
        manifest_urls: List[str],
        manifest_sha256: str,
        manifest_compression: str = "bz2",
        **dataset_args,
    ):
        """Constructor.

        :param manifest_urls: A list of redundant URLS to download the
            compressed text file containing a list of benchmark URI suffixes,
            one per line.
        :param manifest_sha256: The sha256 checksum of the compressed manifest
            file.
        :param manifest_compression: The manifest compression type. One of
            {"bz2", "gz"}.
        :param dataset_args: See :meth:`TarDataset.__init__()
            <compiler_gym.datasets.TarDataset.__init__>`.
        """
        super().__init__(**dataset_args)
        self.manifest_urls = manifest_urls
        self.manifest_sha256 = manifest_sha256
        self.manifest_compression = manifest_compression
        # Keyed by checksum so that a changed manifest is re-downloaded
        # rather than served from a stale cache.
        self._manifest_path = self.site_data_path / f"manifest-{manifest_sha256}.txt"

        # Thread-level and process-level locks guarding manifest download.
        self._manifest_lock = Lock()
        self._manifest_lockfile = self.site_data_path / "manifest.LOCK"

    def _parse_manifest(self, text: str) -> List[str]:
        """Expand newline-separated URI suffixes into fully-qualified URIs."""
        return [f"{self.name}/{line}" for line in text.rstrip().split("\n")]

    def _read_manifest_file(self) -> List[str]:
        """Read the manifest file from disk.

        Does not check that the manifest file exists.
        """
        with open(self._manifest_path, encoding="utf-8") as f:
            uris = self._parse_manifest(f.read())
        self.logger.debug("Read %s manifest, %d entries", self.name, len(uris))
        return uris

    @memoized_property
    def _benchmark_uris(self) -> List[str]:
        """Fetch or download the URI list."""
        if self._manifest_path.is_file():
            return self._read_manifest_file()

        # Hold both the thread-level and process-level locks so concurrent
        # callers cannot race on the download.
        with self._manifest_lock, fasteners.InterProcessLock(
            self._manifest_lockfile
        ):
            # Now that we have acquired the lock, repeat the check.
            if self._manifest_path.is_file():
                return self._read_manifest_file()

            self.logger.debug("Downloading %s manifest", self.name)
            manifest_data = io.BytesIO(
                download(self.manifest_urls, self.manifest_sha256)
            )
            # Decompress the manifest data.
            if self.manifest_compression == "bz2":
                with bz2.BZ2File(manifest_data) as f:
                    manifest_data = f.read()
            elif self.manifest_compression == "gz":
                with gzip.GzipFile(fileobj=manifest_data) as f:
                    manifest_data = f.read()
            else:
                raise TypeError(
                    f"Unknown manifest compression: {self.manifest_compression}"
                )

            # Write to a temporary file and atomically rename to avoid a race
            # with concurrent readers. os.replace() (unlike os.rename())
            # overwrites an existing destination on Windows too.
            with open(f"{self._manifest_path}.tmp", "wb") as f:
                f.write(manifest_data)
            os.replace(f"{self._manifest_path}.tmp", self._manifest_path)

            uris = self._parse_manifest(manifest_data.decode("utf-8"))
            self.logger.debug(
                "Downloaded %s manifest, %d entries", self.name, len(uris)
            )
            return uris

    @memoized_property
    def n(self) -> int:
        """The number of benchmarks, derived from the manifest length."""
        return len(self._benchmark_uris)

    def benchmark_uris(self) -> Iterable[str]:
        """Enumerate the benchmark URIs listed in the manifest."""
        yield from iter(self._benchmark_uris)
Oops, something went wrong.