diff --git a/src/shelf/__init__.py b/src/shelf/__init__.py index 4b3f17a..5a19700 100644 --- a/src/shelf/__init__.py +++ b/src/shelf/__init__.py @@ -7,5 +7,5 @@ pass +from .core import Shelf from .registry import deregister_type, lookup, register_type -from .shelf import Shelf diff --git a/src/shelf/shelf.py b/src/shelf/core.py similarity index 58% rename from src/shelf/shelf.py rename to src/shelf/core.py index c598567..2f323c9 100644 --- a/src/shelf/shelf.py +++ b/src/shelf/core.py @@ -4,41 +4,17 @@ import os import tempfile from os import PathLike -from pathlib import Path from typing import Any, Literal, TypeVar from fsspec import AbstractFileSystem, filesystem -from fsspec.utils import get_protocol, stringify_path +from fsspec.utils import get_protocol -import shelf.registry as registry -from shelf.util import is_fully_qualified +import shelf.registry +from shelf.util import is_fully_qualified, with_trailing_sep T = TypeVar("T") -def load_config(filename: str | Path) -> dict[str, Any]: - def get_project_root() -> Path: - """ - Returns project root if currently in a project (sub-)folder, - otherwise the current directory. - """ - cwd = Path.cwd() - for p in (cwd, *cwd.parents): - if (p / "setup.py").exists() or (p / "pyproject.toml").exists(): - return p - return cwd - - config: dict[str, Any] = {} - - for loc in [Path.home(), get_project_root()]: - if (pp := loc / filename).exists(): - with open(pp, "r") as f: - import yaml - - config = yaml.safe_load(f) - return config - - class Shelf: def __init__( self, @@ -46,6 +22,7 @@ def __init__( cache_dir: str | PathLike[str] | None = None, cache_type: Literal["blockcache", "filecache", "simplecache"] = "filecache", fsconfig: dict[str, dict[str, Any]] | None = None, + configfile: str | PathLike[str] | None = None, ): self.prefix = str(prefix) @@ -53,17 +30,24 @@ def __init__( self.cache_dir = cache_dir # config object holding storage options for file systems - self.fsconfig = fsconfig or {} + # TODO: Validate schema for inputs + if configfile and not fsconfig: + import yaml + + with open(configfile, "r") as f: + self.fsconfig = yaml.safe_load(f) + else: + self.fsconfig = fsconfig or {} def get(self, rpath: str, expected_type: type[T]) -> T: + # load machinery early, so that we do not download + # if the type is not registered. + serde = shelf.registry.lookup(expected_type) + if not is_fully_qualified(rpath): rpath = os.path.join(self.prefix, rpath) - # load machinery early, so that we do not download - # if the type is not registered. - serde = registry.lookup(expected_type) protocol = get_protocol(rpath) - # file system-specific options. config = self.fsconfig.get(protocol, {}) storage_options = config.get("storage", {}) @@ -81,27 +65,43 @@ def get(self, rpath: str, expected_type: type[T]) -> T: fs: AbstractFileSystem = filesystem(proto, **kwargs) - download_options = config.get("download", {}) + try: + rfiles = fs.ls(rpath, detail=False) + # some file systems (e.g. local) don't allow filenames in `ls` + except NotADirectoryError: + rfiles = [fs.info(rpath)["name"]] + + if not rfiles: + raise FileNotFoundError(rpath) with contextlib.ExitStack() as stack: tmpdir = stack.enter_context(tempfile.TemporaryDirectory()) + # TODO: Push a unique directory (e.g. checksum) in front to + # create a directory + + # explicit file lists have the side effect that remote subdirectory structures + # are flattened. + lfiles = [os.path.join(tmpdir, os.path.basename(f)) for f in rfiles] - # trailing slash tells fsspec to download files into `lpath` - lpath = stringify_path(tmpdir.rstrip(os.sep) + os.sep) - fs.get(rpath, lpath, **download_options) + download_options = config.get("download", {}) + fs.get(rfiles, lfiles, **download_options) - # TODO: Find a way to pass files in expected order - files = [str(p) for p in Path(tmpdir).iterdir() if p.is_file()] - if not files: - raise ValueError(f"no files found for rpath {rpath!r}") - obj: T = serde.deserializer(*files) + # TODO: Support deserializer interfaces taking unraveled tuples, i.e. filenames + # as arguments in the multifile case + lpath: str | tuple[str, ...] + if len(lfiles) == 1: + lpath = lfiles[0] + else: + lpath = tuple(lfiles) + + obj: T = serde.deserializer(lpath) return obj def put(self, obj: T, rpath: str) -> None: # load machinery early, so that we do not download # if the type is not registered. - serde = registry.lookup(type(obj)) + serde = shelf.registry.lookup(type(obj)) if not is_fully_qualified(rpath): rpath = os.path.join(self.prefix, rpath) @@ -127,10 +127,15 @@ def put(self, obj: T, rpath: str) -> None: with contextlib.ExitStack() as stack: tmpdir = stack.enter_context(tempfile.TemporaryDirectory()) - # TODO: What about multiple lpaths? lpath = serde.serializer(obj, tmpdir) + recursive = isinstance(lpath, (list, tuple)) + if recursive: + # signals fsspec to put all files into rpath directory + rpath = with_trailing_sep(rpath) + upload_options = fsconfig.get("upload", {}) - fs.put(lpath, rpath, **upload_options) + # TODO: Construct explicit lists always to hit the fast path of fs.put() + fs.put(lpath, rpath, recursive=recursive, **upload_options) return fs.info(rpath) diff --git a/src/shelf/util.py b/src/shelf/util.py index f272d7b..6dc5a2a 100644 --- a/src/shelf/util.py +++ b/src/shelf/util.py @@ -1,3 +1,5 @@ +import os + from fsspec.utils import get_protocol, stringify_path @@ -5,3 +7,7 @@ def is_fully_qualified(path: str) -> bool: path = stringify_path(path) protocol = get_protocol(path) return any(path.startswith(protocol + sep) for sep in ("::", "://")) + + +def with_trailing_sep(path: str) -> str: + return path if path.endswith(os.sep) else path + os.sep diff --git a/tests/conftest.py b/tests/conftest.py index c1b9b80..185d59e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,12 @@ -from typing import Generator +from pathlib import Path +from typing import Any, Generator import pytest +import yaml -from shelf.registry import _registry +import shelf.registry + +testdir = Path(__file__).parent @pytest.fixture(autouse=True) @@ -11,4 +15,10 @@ def empty_registry() -> Generator[None, None, None]: try: yield finally: - _registry.clear() + shelf.registry._registry.clear() + + +@pytest.fixture(scope="session") +def fsconfig() -> dict[str, dict[str, Any]]: + with open(testdir / "shelfconfig.yaml", "r") as f: + return yaml.safe_load(f) diff --git a/tests/shelfconfig.yaml b/tests/shelfconfig.yaml new file mode 100644 index 0000000..dea7951 --- /dev/null +++ b/tests/shelfconfig.yaml @@ -0,0 +1,3 @@ +file: + storage: + auto_mkdir: true diff --git a/tests/test_shelf.py b/tests/test_shelf.py index 3338926..c2ec832 100644 --- a/tests/test_shelf.py +++ b/tests/test_shelf.py @@ -28,3 +28,39 @@ def json_load(fname: str) -> dict: data2 = s.get("myobj.json", dict) assert data == data2 + + +def test_multifile_artifact(tmp_path: Path, fsconfig: dict) -> None: + """ + Test a dict artifact JSON roundtrip with the dict serialized into two different files. + + No nested directories, only multiple filenames. + """ + + def json_dump(d: dict, tmpdir: str) -> tuple[str, ...]: + d1, d2 = {"a": d["a"]}, {"b": d["b"]} + fnames = [] + for i, d in enumerate((d1, d2)): + fname = os.path.join(tmpdir, f"dump{i}.json") + fnames.append(fname) + with open(fname, "w") as f: + json.dump(d, f) + return tuple(fnames) + + def json_load(fnames: tuple[str, str]) -> dict: + d: dict = {} + for fname in fnames: + with open(fname, "r") as f: + d |= json.load(f) + return d + + shelf.register_type(dict, json_dump, json_load) + + s = shelf.Shelf(prefix=tmp_path, fsconfig=fsconfig) + + data = {"a": 1, "b": 2} + + s.put(data, "myobj") + data2 = s.get("myobj", dict) + + assert data == data2