Skip to content

Commit

Permalink
Merge pull request #7 from nicholasjng/multifile-artifacts
Browse files Browse the repository at this point in the history
Add support for handling multi-file artifacts
  • Loading branch information
nicholasjng authored Dec 24, 2023
2 parents 39d3088 + e786812 commit 7061296
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/shelf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
pass


from .core import Shelf
from .registry import deregister_type, lookup, register_type
from .shelf import Shelf
93 changes: 49 additions & 44 deletions src/shelf/shelf.py → src/shelf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,66 +4,50 @@
import os
import tempfile
from os import PathLike
from pathlib import Path
from typing import Any, Literal, TypeVar

from fsspec import AbstractFileSystem, filesystem
from fsspec.utils import get_protocol, stringify_path
from fsspec.utils import get_protocol

import shelf.registry as registry
from shelf.util import is_fully_qualified
import shelf.registry
from shelf.util import is_fully_qualified, with_trailing_sep

T = TypeVar("T")


def load_config(filename: str | Path) -> dict[str, Any]:
def get_project_root() -> Path:
"""
Returns project root if currently in a project (sub-)folder,
otherwise the current directory.
"""
cwd = Path.cwd()
for p in (cwd, *cwd.parents):
if (p / "setup.py").exists() or (p / "pyproject.toml").exists():
return p
return cwd

config: dict[str, Any] = {}

for loc in [Path.home(), get_project_root()]:
if (pp := loc / filename).exists():
with open(pp, "r") as f:
import yaml

config = yaml.safe_load(f)
return config


class Shelf:
def __init__(
self,
prefix: str | os.PathLike[str] = "",
cache_dir: str | PathLike[str] | None = None,
cache_type: Literal["blockcache", "filecache", "simplecache"] = "filecache",
fsconfig: dict[str, dict[str, Any]] | None = None,
configfile: str | PathLike[str] | None = None,
):
self.prefix = str(prefix)

self.cache_type = cache_type
self.cache_dir = cache_dir

# config object holding storage options for file systems
self.fsconfig = fsconfig or {}
# TODO: Validate schema for inputs
if configfile and not fsconfig:
import yaml

with open(configfile, "r") as f:
self.fsconfig = yaml.safe_load(f)
else:
self.fsconfig = fsconfig or {}

def get(self, rpath: str, expected_type: type[T]) -> T:
# load machinery early, so that we do not download
# if the type is not registered.
serde = shelf.registry.lookup(expected_type)

if not is_fully_qualified(rpath):
rpath = os.path.join(self.prefix, rpath)

# load machinery early, so that we do not download
# if the type is not registered.
serde = registry.lookup(expected_type)
protocol = get_protocol(rpath)

# file system-specific options.
config = self.fsconfig.get(protocol, {})
storage_options = config.get("storage", {})
Expand All @@ -81,27 +65,43 @@ def get(self, rpath: str, expected_type: type[T]) -> T:

fs: AbstractFileSystem = filesystem(proto, **kwargs)

download_options = config.get("download", {})
try:
rfiles = fs.ls(rpath, detail=False)
# some file systems (e.g. local) don't allow filenames in `ls`
except NotADirectoryError:
rfiles = [fs.info(rpath)["name"]]

if not rfiles:
raise FileNotFoundError(rpath)

with contextlib.ExitStack() as stack:
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
# TODO: Push a unique directory (e.g. checksum) in front to
# create a directory

# explicit file lists have the side effect that remote subdirectory structures
# are flattened.
lfiles = [os.path.join(tmpdir, os.path.basename(f)) for f in rfiles]

# trailing slash tells fsspec to download files into `lpath`
lpath = stringify_path(tmpdir.rstrip(os.sep) + os.sep)
fs.get(rpath, lpath, **download_options)
download_options = config.get("download", {})
fs.get(rfiles, lfiles, **download_options)

# TODO: Find a way to pass files in expected order
files = [str(p) for p in Path(tmpdir).iterdir() if p.is_file()]
if not files:
raise ValueError(f"no files found for rpath {rpath!r}")
obj: T = serde.deserializer(*files)
# TODO: Support deserializer interfaces taking unraveled tuples, i.e. filenames
# as arguments in the multifile case
lpath: str | tuple[str, ...]
if len(lfiles) == 1:
lpath = lfiles[0]
else:
lpath = tuple(lfiles)

obj: T = serde.deserializer(lpath)

return obj

def put(self, obj: T, rpath: str) -> None:
# load machinery early, so that we do not download
# if the type is not registered.
serde = registry.lookup(type(obj))
serde = shelf.registry.lookup(type(obj))

if not is_fully_qualified(rpath):
rpath = os.path.join(self.prefix, rpath)
Expand All @@ -127,10 +127,15 @@ def put(self, obj: T, rpath: str) -> None:

with contextlib.ExitStack() as stack:
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
# TODO: What about multiple lpaths?
lpath = serde.serializer(obj, tmpdir)

recursive = isinstance(lpath, (list, tuple))
if recursive:
# signals fsspec to put all files into rpath directory
rpath = with_trailing_sep(rpath)

upload_options = fsconfig.get("upload", {})
fs.put(lpath, rpath, **upload_options)
# TODO: Construct explicit lists always to hit the fast path of fs.put()
fs.put(lpath, rpath, recursive=recursive, **upload_options)

return fs.info(rpath)
6 changes: 6 additions & 0 deletions src/shelf/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os

from fsspec.utils import get_protocol, stringify_path


def is_fully_qualified(path: str) -> bool:
    """
    Return True if *path* names its protocol explicitly, i.e. it starts
    with "<protocol>://" or the chained form "<protocol>::".
    """
    normalized = stringify_path(path)
    proto = get_protocol(normalized)
    for sep in ("::", "://"):
        if normalized.startswith(proto + sep):
            return True
    return False


def with_trailing_sep(path: str) -> str:
    """
    Return *path*, guaranteed to end with the OS path separator.

    NOTE(review): uses ``os.sep``, so on Windows this appends a backslash
    even for URL-style remote paths — confirm that is intended for rpaths.
    """
    if path.endswith(os.sep):
        return path
    return path + os.sep
16 changes: 13 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Generator
from pathlib import Path
from typing import Any, Generator

import pytest
import yaml

from shelf.registry import _registry
import shelf.registry

testdir = Path(__file__).parent


@pytest.fixture(autouse=True)
Expand All @@ -11,4 +15,10 @@ def empty_registry() -> Generator[None, None, None]:
try:
yield
finally:
_registry.clear()
shelf.registry._registry.clear()


@pytest.fixture(scope="session")
def fsconfig() -> dict[str, dict[str, Any]]:
    """Session-scoped fixture: file system options parsed from tests/shelfconfig.yaml."""
    config_path = testdir / "shelfconfig.yaml"
    with open(config_path, "r") as f:
        return yaml.safe_load(f)
3 changes: 3 additions & 0 deletions tests/shelfconfig.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
file:
storage:
auto_mkdir: true
36 changes: 36 additions & 0 deletions tests/test_shelf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,39 @@ def json_load(fname: str) -> dict:
data2 = s.get("myobj.json", dict)

assert data == data2


def test_multifile_artifact(tmp_path: Path, fsconfig: dict) -> None:
    """
    Test a dict artifact JSON roundtrip with the dict serialized into two
    different files. No nested directories, only multiple filenames.
    """

    def json_dump(d: dict, tmpdir: str) -> tuple[str, ...]:
        # Split the dict into two single-key halves, one output file each.
        halves = ({"a": d["a"]}, {"b": d["b"]})
        fnames = []
        # Loop variable renamed from `d` so it no longer shadows the
        # enclosing parameter `d`.
        for i, half in enumerate(halves):
            fname = os.path.join(tmpdir, f"dump{i}.json")
            fnames.append(fname)
            with open(fname, "w") as f:
                json.dump(half, f)
        return tuple(fnames)

    def json_load(fnames: tuple[str, str]) -> dict:
        # Merge the per-file dicts back into a single dict.
        merged: dict = {}
        for fname in fnames:
            with open(fname, "r") as f:
                merged |= json.load(f)
        return merged

    shelf.register_type(dict, json_dump, json_load)

    s = shelf.Shelf(prefix=tmp_path, fsconfig=fsconfig)

    data = {"a": 1, "b": 2}

    s.put(data, "myobj")
    data2 = s.get("myobj", dict)

    assert data == data2

0 comments on commit 7061296

Please sign in to comment.