[datasets] Add Dataset helper subclasses.
Issue #45.
ChrisCummins committed Apr 22, 2021
1 parent 75730f1 commit 7b86fde
Showing 6 changed files with 527 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler_gym/datasets/BUILD
@@ -10,6 +10,8 @@ py_library(
"__init__.py",
"benchmark.py",
"dataset.py",
"files_dataset.py",
"tar_dataset.py",
],
visibility = ["//visibility:public"],
deps = [
5 changes: 5 additions & 0 deletions compiler_gym/datasets/__init__.py
@@ -17,6 +17,8 @@
delete,
require,
)
from compiler_gym.datasets.files_dataset import FilesDataset
from compiler_gym.datasets.tar_dataset import TarDataset, TarDatasetWithManifest

__all__ = [
"activate",
@@ -27,6 +29,9 @@
"DatasetInitError",
"deactivate",
"delete",
"FilesDataset",
"LegacyDataset",
"require",
"TarDataset",
"TarDatasetWithManifest",
]
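The new helper classes are re-exported from the package root, so downstream code can import them directly. A minimal sketch, grounded in the __all__ additions above:

from compiler_gym.datasets import FilesDataset, TarDataset, TarDatasetWithManifest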
156 changes: 156 additions & 0 deletions compiler_gym/datasets/files_dataset.py
@@ -0,0 +1,156 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path
from typing import Iterable, List, Optional

from compiler_gym.datasets.dataset import Benchmark, Dataset
from compiler_gym.util.decorators import memoized_property


class FilesDataset(Dataset):
"""A dataset comprising a directory tree files.
A FilesDataset is a root directory that contains (possibly nested) files,
where each file represents a benchmark. Files can be filtered on their
expected filename suffix.
Every file that matches the expected suffix is a benchmark. The URI of the
benchmarks is the relative path of the file within the dataset, stripped of
any matching file suffix. For example, given a dataset root around the
directory tree :code:`/tmp/dataset` and with file suffix :code:`.txt`:
.. code-block::

        /tmp/dataset/a.txt
        /tmp/dataset/LICENSE
        /tmp/dataset/subdir/subdir/b.txt
        /tmp/dataset/subdir/subdir/c.txt

    a FilesDataset representing this directory tree will contain the following
    URIs:

    >>> list(dataset.benchmark_uris())
    [
        "benchmark://ds-v0/a",
        "benchmark://ds-v0/subdir/subdir/b",
        "benchmark://ds-v0/subdir/subdir/c",
    ]
    """

def __init__(
self,
dataset_root: Path,
benchmark_file_suffix: str = "",
memoize_uris: bool = True,
**dataset_args,
):
"""Constructor.
:param dataset_root: The root directory to look for benchmark files.
:param benchmark_file_suffix: A file extension that must be matched for
a file to be used as a benchmark.
:param memoize_uris: Whether to memoize the list of URIs contained in
the dataset. Memoizing the URIs is a tradeoff between *O(n)*
computation complexity of random access vs *O(n)* space complexity
of memoizing the URI list.
:param dataset_args: See :meth:`Dataset.__init__()
<compiler_gym.datasets.Dataset.__init__>`.
"""
super().__init__(**dataset_args)
self.dataset_root = dataset_root
self.benchmark_file_suffix = benchmark_file_suffix
self.memoize_uris = memoize_uris
self._memoized_uris = None

@memoized_property
    def n(self) -> int:  # pylint: disable=invalid-overridden-method
self.install()
return sum(
sum(1 for f in files if f.endswith(self.benchmark_file_suffix))
for (_, _, files) in os.walk(self.dataset_root)
)

@property
def _benchmark_uris_iter(self) -> Iterable[str]:
"""Return an iterator over benchmark URIs is consistent across runs."""
self.install()
for root, dirs, files in os.walk(self.dataset_root):
dirs.sort()
reldir = root[len(str(self.dataset_root)) + 1 :]
for filename in sorted(files):
# If we have an expected file suffix then filter on it and strip
# it from the filename.
if self.benchmark_file_suffix:
if not filename.endswith(self.benchmark_file_suffix):
continue
filename = filename[: -len(self.benchmark_file_suffix)]
                # Use os.path.join() rather than simple '/' concatenation as
                # reldir may be empty.
yield os.path.join(self.name, reldir, filename)

@property
def _benchmark_uris(self) -> List[str]:
return list(self._benchmark_uris_iter)

def benchmark_uris(self) -> Iterable[str]:
if self._memoized_uris:
yield from self._memoized_uris
elif self.memoize_uris:
self._memoized_uris = self._benchmark_uris
yield from self._memoized_uris
else:
yield from self._benchmark_uris_iter

def benchmark(self, uri: Optional[str] = None) -> Benchmark:
self.install()
if uri is None or len(uri) <= len(self.name) + 1:
if not self.n:
raise ValueError("No benchmarks")
return self.get_benchmark_by_index(self.random.integers(self.n))

relpath = f"{uri[len(self.name) + 1:]}{self.benchmark_file_suffix}"
abspath = self.dataset_root / relpath
if not abspath.is_file():
raise LookupError(f"Benchmark not found: {uri} (file not found: {abspath})")
return self.benchmark_class.from_file(uri, abspath)

def get_benchmark_by_index(self, n: int) -> Benchmark:
"""Look a benchmark using a numeric index into the list of URIs."""
# If we have memoized the benchmark IDs then just index into the list.
# Otherwise we will scan through the directory hierarchy.
if self.memoize_uris:
if not self._memoized_uris:
self._memoized_uris = self._benchmark_uris
return self.benchmark(self._memoized_uris[n])

i = 0
        for root, dirs, files in os.walk(self.dataset_root):
            # Sort the subdirectories in-place so that os.walk() visits them
            # in a deterministic order, whichever branch is taken below. This
            # keeps the numeric indices consistent with benchmark_uris().
            dirs.sort()
            reldir = root[len(str(self.dataset_root)) + 1 :]

            # Filter only the files that match the target file suffix.
            valid_files = [f for f in files if f.endswith(self.benchmark_file_suffix)]

            if i + len(valid_files) < n:
                # There aren't enough files in this directory to bring us up to
                # the target file index, so skip the whole lot.
                i += len(valid_files)
            else:
# Iterate through the files in the directory in order until we
# reach the target index.
for filename in sorted(valid_files):
if i == n:
name_stem = filename[: -len(self.benchmark_file_suffix)]
                        # Use os.path.join() rather than simple '/'
                        # concatenation as reldir may be empty.
uri = os.path.join(self.name, reldir, name_stem)
return self.benchmark_class.from_file(uri, f"{root}/{filename}")
i += 1

raise FileNotFoundError(f"Could not find benchmark with index {n} / {self.n}")
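For illustration, a minimal usage sketch of FilesDataset. The name keyword argument forwarded to the base Dataset constructor is an assumption for this example, since Dataset.__init__() is not part of this diff:

from pathlib import Path

from compiler_gym.datasets import FilesDataset

# `name` is a hypothetical base Dataset argument, matching the URI scheme
# used in the class docstring above.
dataset = FilesDataset(
    name="benchmark://ds-v0",
    dataset_root=Path("/tmp/dataset"),
    benchmark_file_suffix=".txt",
)

# Deterministic enumeration of benchmark URIs:
for uri in dataset.benchmark_uris():
    print(uri)  # e.g. "benchmark://ds-v0/a"

# Random-access lookup, either by URI or by numeric index:
a = dataset.benchmark("benchmark://ds-v0/a")
b = dataset.get_benchmark_by_index(0)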
204 changes: 204 additions & 0 deletions compiler_gym/datasets/tar_dataset.py
@@ -0,0 +1,204 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bz2
import gzip
import io
import os
import shutil
import tarfile
from threading import Lock
from typing import Iterable, List, Optional

import fasteners

from compiler_gym.datasets.files_dataset import FilesDataset
from compiler_gym.util.decorators import memoized_property
from compiler_gym.util.download import download


class TarDataset(FilesDataset):
"""A dataset comprising a files tree stored in a tar archive.
This extends the :class:`FilesDataset` class by adding support for
compressed archives of files. The archive is downloaded and unpacked
on-demand.
"""

def __init__(
self,
tar_urls: List[str],
tar_sha256: Optional[str] = None,
tar_compression: str = "bz2",
strip_prefix: str = "",
**dataset_args,
):
"""Constructor.
:param tar_urls: A list of redundant URLS to download the tar archive from.
:param tar_sha256: The SHA256 checksum of the downloaded tar archive.
:param tar_compression: The tar archive compression type. One of
{"bz2", "gz"}.
:param strip_prefix: An optional path prefix to strip. Only files that
match this path prefix will be used as benchmarks.
:param dataset_args: See :meth:`FilesDataset.__init__()
<compiler_gym.datasets.FilesDataset.__init__>`.
"""
super().__init__(
dataset_root=None, # Set below once site_data_path is resolved.
**dataset_args,
)
self.dataset_root = self.site_data_path / "contents" / strip_prefix

self.tar_urls = tar_urls
self.tar_sha256 = tar_sha256
self.tar_compression = tar_compression
self.strip_prefix = strip_prefix

self._installed = False
self._tar_extracted_marker = self.site_data_path / ".extracted"
self._tar_lock = Lock()
self._tar_lockfile = self.site_data_path / "LOCK"

@property
def installed(self) -> bool:
# Fast path for repeated checks to 'installed' without a disk op.
if not self._installed:
self._installed = self._tar_extracted_marker.is_file()
return self._installed

def install(self) -> None:
if self.installed:
return

# Thread-level and process-level locks to prevent races.
with self._tar_lock, fasteners.InterProcessLock(self._tar_lockfile):
# Repeat the check to see if we have already installed the
# dataset now that we have acquired the lock.
if self.installed:
return

self.logger.info("Downloading %s dataset", self.name)
tar_data = io.BytesIO(download(self.tar_urls, self.tar_sha256))
self.logger.info("Unpacking %s dataset", self.name)
with tarfile.open(
fileobj=tar_data, mode=f"r:{self.tar_compression}"
) as arc:
# Remove any partially-completed prior extraction.
shutil.rmtree(self.site_data_path / "contents", ignore_errors=True)

arc.extractall(str(self.site_data_path / "contents"))

self._tar_extracted_marker.touch()

if self.strip_prefix and not self.dataset_root.is_dir():
raise FileNotFoundError(
f"Directory prefix '{self.strip_prefix}' not found in dataset '{self.name}'"
)


class TarDatasetWithManifest(TarDataset):
"""A tarball-based dataset that has a manifest file which lists the URIs.
The idea is to allow the list of benchmark URIs to be enumerated in a more
lightweight manner than downloading and unpacking the entire dataset. It
does this by downloading a "manifest", which is a list of benchmark names,
and only downloads the actual tarball containing the benchmarks when it is
needed.
The manifest is assumed to be correct and is not validated.
"""

def __init__(
self,
manifest_urls: List[str],
manifest_sha256: str,
manifest_compression: str = "bz2",
**dataset_args,
):
"""Constructor.
:param manifest_urls: A list of redundant URLS to download the
compressed text file containing a list of benchmark URI suffixes,
one per line.
:param manifest_sha256: The sha256 checksum of the compressed manifest
file.
:param manifest_compression: The manifest compression type. One of
{"bz2", "gz"}.
:param dataset_args: See :meth:`TarDataset.__init__()
<compiler_gym.datasets.TarDataset.__init__>`.
"""
super().__init__(**dataset_args)
self.manifest_urls = manifest_urls
self.manifest_sha256 = manifest_sha256
self.manifest_compression = manifest_compression
self._manifest_path = self.site_data_path / f"manifest-{manifest_sha256}.txt"

self._manifest_lock = Lock()
self._manifest_lockfile = self.site_data_path / "manifest.LOCK"

def _read_manifest_file(self) -> List[str]:
"""Read the manifest file from disk.
Does not check that the manifest file exists.
"""
with open(self._manifest_path, encoding="utf-8") as f:
lines = f.read().rstrip().split("\n")
uris = [f"{self.name}/{line}" for line in lines]
self.logger.debug("Read %s manifest, %d entries", self.name, len(uris))
return uris

@memoized_property
def _benchmark_uris(self) -> List[str]:
"""Fetch or download the URI list."""
if self._manifest_path.is_file():
return self._read_manifest_file()

with self._manifest_lock:
with fasteners.InterProcessLock(self._manifest_lockfile):
# Now that we have acquired the lock, repeat the check.
if self._manifest_path.is_file():
return self._read_manifest_file()

# Decompress the manifest data.
self.logger.debug("Downloading %s manifest", self.name)
manifest_data = io.BytesIO(
download(self.manifest_urls, self.manifest_sha256)
)
if self.manifest_compression == "bz2":
with bz2.BZ2File(manifest_data) as f:
manifest_data = f.read()
elif self.manifest_compression == "gz":
with gzip.GzipFile(fileobj=manifest_data) as f:
manifest_data = f.read()
else:
raise TypeError(
f"Unknown manifest compression: {self.manifest_compression}"
)

# Write to a temporary file and rename to avoid a race.
with open(f"{self._manifest_path}.tmp", "wb") as f:
f.write(manifest_data)
os.rename(f"{self._manifest_path}.tmp", self._manifest_path)

lines = manifest_data.decode("utf-8").rstrip().split("\n")
uris = [f"{self.name}/{line}" for line in lines]
self.logger.debug(
"Downloaded %s manifest, %d entries", self.name, len(lines)
)
return uris

@memoized_property
def n(self) -> int:
return len(self._benchmark_uris)

def benchmark_uris(self) -> Iterable[str]:
yield from iter(self._benchmark_uris)
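To sketch how these pieces compose, here is a hedged example of a concrete dataset built on TarDatasetWithManifest. All URLs and checksums are placeholders, and the name argument forwarded to the base Dataset constructor is an assumption, as Dataset.__init__() is not shown in this diff:

from compiler_gym.datasets import TarDatasetWithManifest


class MyTarDataset(TarDatasetWithManifest):
    """A hypothetical dataset distributed as a tarball plus manifest."""

    def __init__(self, **kwargs):
        super().__init__(
            name="benchmark://my-ds-v0",  # Assumed base Dataset argument.
            tar_urls=["https://example.com/my-ds-v0.tar.bz2"],  # Placeholder.
            tar_sha256="0" * 64,  # Placeholder checksum.
            manifest_urls=["https://example.com/my-ds-v0-manifest.bz2"],  # Placeholder.
            manifest_sha256="0" * 64,  # Placeholder checksum.
            strip_prefix="my-ds-v0",
            benchmark_file_suffix=".c",
            **kwargs,
        )

With this arrangement, enumerating benchmark_uris() touches only the lightweight manifest; the tarball itself is downloaded and unpacked lazily, on the first benchmark() lookup.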