diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index 16cd925ce..feeeb9665 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -102,7 +102,7 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]: if "container_path" in field.metadata: continue inp_dict[field.name] = getattr(self, field.name) - hash_cache = Cache({}) + hash_cache = Cache() field_hashes = { k: hash_function(v, cache=hash_cache) for k, v in inp_dict.items() } diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py index e2610c9bd..dc107a130 100644 --- a/pydra/engine/submitter.py +++ b/pydra/engine/submitter.py @@ -7,6 +7,7 @@ from .workers import Worker, WORKERS from .core import is_workflow from .helpers import get_open_loop, load_and_run_async +from ..utils.hash import PersistentCache import logging @@ -54,6 +55,7 @@ def __call__(self, runnable, cache_locations=None, rerun=False, environment=None self.loop.run_until_complete( self.submit_from_call(runnable, rerun, environment) ) + PersistentCache().clean_up() return runnable.result() async def submit_from_call(self, runnable, rerun, environment): diff --git a/pydra/engine/tests/test_node_task.py b/pydra/engine/tests/test_node_task.py index 37ed90d03..bceaf9740 100644 --- a/pydra/engine/tests/test_node_task.py +++ b/pydra/engine/tests/test_node_task.py @@ -1,8 +1,15 @@ import os import shutil import attr +import typing as ty import numpy as np +import time +from unittest import mock +from pathlib import Path import pytest +import time +from fileformats.generic import File +import pydra.mark from .utils import ( fun_addtwo, @@ -306,6 +313,7 @@ def test_task_init_7(tmp_path): output_dir1 = nn1.output_dir # changing the content of the file + time.sleep(2) # need the mtime to be different file2 = tmp_path / "file2.txt" with open(file2, "w") as f: f.write("from pydra") @@ -1560,3 +1568,98 @@ def test_task_state_cachelocations_updated(plugin, tmp_path): # both workflows should be run assert all([dir.exists() for dir in nn.output_dir]) assert all([dir.exists() for dir in nn2.output_dir]) + + +def test_task_files_cachelocations(plugin_dask_opt, tmp_path): + """ + Two identical tasks with provided cache_dir that use file as an input; + the second task has cache_locations and should not recompute the results + """ + cache_dir = tmp_path / "test_task_nostate" + cache_dir.mkdir() + cache_dir2 = tmp_path / "test_task_nostate2" + cache_dir2.mkdir() + input_dir = tmp_path / "input" + input_dir.mkdir() + + input1 = input_dir / "input1.txt" + input1.write_text("test") + input2 = input_dir / "input2.txt" + input2.write_text("test") + + nn = fun_file(name="NA", filename=input1, cache_dir=cache_dir) + with Submitter(plugin=plugin_dask_opt) as sub: + sub(nn) + + nn2 = fun_file( + name="NA", filename=input2, cache_dir=cache_dir2, cache_locations=cache_dir + ) + with Submitter(plugin=plugin_dask_opt) as sub: + sub(nn2) + + # checking the results + results2 = nn2.result() + assert results2.output.out == "test" + + # checking if the second task didn't run the interface again + assert nn.output_dir.exists() + assert not nn2.output_dir.exists() + + +class OverriddenContentsFile(File): + """A class for testing purposes, to that enables you to override the contents + of the file to allow you to check whether the persistent cache is used.""" + + def __init__( + self, + fspaths: ty.Iterator[Path], + contents: ty.Optional[bytes] = None, + metadata: ty.Dict[str, ty.Any] = None, + ): + super().__init__(fspaths, metadata=metadata) + self._contents = contents + + def byte_chunks(self, **kwargs) -> ty.Generator[ty.Tuple[str, bytes], None, None]: + if self._contents is not None: + yield (str(self.fspath), iter([self._contents])) + else: + yield from super().byte_chunks(**kwargs) + + @property + def contents(self): + if self._contents is not None: + return self._contents + return super().contents + + +def test_task_files_persistentcache(tmp_path): + """ + Two identical tasks with provided cache_dir that use file as an input; + the second task has cache_locations and should not recompute the results + """ + test_file_path = tmp_path / "test_file.txt" + test_file_path.write_bytes(b"foo") + cache_dir = tmp_path / "cache-dir" + cache_dir.mkdir() + test_file = OverriddenContentsFile(test_file_path) + + @pydra.mark.task + def read_contents(x: OverriddenContentsFile) -> bytes: + return x.contents + + assert ( + read_contents(x=test_file, cache_dir=cache_dir)(plugin="serial").output.out + == b"foo" + ) + test_file._contents = b"bar" + # should return result from the first run using the persistent cache + assert ( + read_contents(x=test_file, cache_dir=cache_dir)(plugin="serial").output.out + == b"foo" + ) + time.sleep(2) # Windows has a 2-second resolution for mtime + test_file_path.touch() # update the mtime to invalidate the persistent cache value + assert ( + read_contents(x=test_file, cache_dir=cache_dir)(plugin="serial").output.out + == b"bar" + ) # returns the overridden value diff --git a/pydra/engine/tests/test_specs.py b/pydra/engine/tests/test_specs.py index f7edd0f57..4f54cd404 100644 --- a/pydra/engine/tests/test_specs.py +++ b/pydra/engine/tests/test_specs.py @@ -3,6 +3,7 @@ import os import attrs from copy import deepcopy +import time from ..specs import ( BaseSpec, @@ -163,6 +164,7 @@ def test_input_file_hash_2(tmp_path): assert hash1 == hash2 # checking if different content (the same name) affects the hash + time.sleep(2) # ensure mtime is different file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") @@ -193,6 +195,7 @@ def test_input_file_hash_2a(tmp_path): assert hash1 == hash2 # checking if different content (the same name) affects the hash + time.sleep(2) # ensure mtime is different file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") @@ -234,6 +237,7 @@ def test_input_file_hash_3(tmp_path): # assert id(files_hash1["in_file"][filename]) == id(files_hash2["in_file"][filename]) # recreating the file + time.sleep(2) # ensure mtime is different with open(file, "w") as f: f.write("hello") @@ -288,6 +292,7 @@ def test_input_file_hash_4(tmp_path): assert hash1 == hash2 # checking if different content (the same name) affects the hash + time.sleep(2) # need the mtime to be different file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") @@ -324,6 +329,7 @@ def test_input_file_hash_5(tmp_path): assert hash1 == hash2 # checking if different content (the same name) affects the hash + time.sleep(2) # ensure mtime is different file_diffcontent = tmp_path / "in_file_1.txt" with open(file_diffcontent, "w") as f: f.write("hi") diff --git a/pydra/utils/__init__.py b/pydra/utils/__init__.py index e69de29bb..7fe4b8595 100644 --- a/pydra/utils/__init__.py +++ b/pydra/utils/__init__.py @@ -0,0 +1,11 @@ +from pathlib import Path +import platformdirs +from pydra._version import __version__ + +user_cache_dir = Path( + platformdirs.user_cache_dir( + appname="pydra", + appauthor="nipype", + version=__version__, + ) +) diff --git a/pydra/utils/hash.py b/pydra/utils/hash.py index 74d3b3a44..90e132d1e 100644 --- a/pydra/utils/hash.py +++ b/pydra/utils/hash.py @@ -1,16 +1,14 @@ """Generic object hashing dispatch""" import os - -# import stat import struct +from datetime import datetime import typing as ty +from pathlib import Path from collections.abc import Mapping from functools import singledispatch from hashlib import blake2b import logging - -# from pathlib import Path from typing import ( Dict, Iterator, @@ -18,7 +16,10 @@ Sequence, Set, ) +from filelock import SoftFileLock import attrs.exceptions +from fileformats.core import FileSet +from . import user_cache_dir logger = logging.getLogger("pydra") @@ -52,19 +53,164 @@ ) Hash = NewType("Hash", bytes) -Cache = NewType("Cache", Dict[int, Hash]) +CacheKey = NewType("CacheKey", ty.Tuple[ty.Hashable, ty.Hashable]) + + +def location_converter(path: ty.Union[Path, str, None]) -> Path: + if path is None: + path = PersistentCache.location_default() + path = Path(path) + if not path.exists(): + path.mkdir(parents=True) + return path + + +@attrs.define +class PersistentCache: + """Persistent cache in which to store computationally expensive hashes between nodes + and workflow/task runs. It does this in via the `get_or_calculate_hash` method, which + takes a locally unique key (e.g. file-system path + mtime) and a function to + calculate the hash if it isn't present in the persistent store. + + The locally unique key is hashed (cheaply) using hashlib cryptography and this + "local hash" is use to name the entry of the (potentially expensive) hash of the + object itself (e.g. the contents of a file). This entry is saved as a text file + within a user-specific cache directory (see `platformdirs.user_cache_dir`), with + the name of the file being the "local hash" of the key and the contents of the + file being the "globally unique hash" of the object itself. + + Parameters + ---------- + location: Path + the directory in which to store the hashes cache + """ + + location: Path = attrs.field(converter=location_converter) # type: ignore[misc] + cleanup_period: int = attrs.field() + _hashes: ty.Dict[CacheKey, Hash] = attrs.field(factory=dict) + + # Set the location of the persistent hash cache + LOCATION_ENV_VAR = "PYDRA_HASH_CACHE" + CLEANUP_ENV_VAR = "PYDRA_HASH_CACHE_CLEANUP_PERIOD" + + @classmethod + def location_default(cls): + try: + location = os.environ[cls.LOCATION_ENV_VAR] + except KeyError: + location = user_cache_dir / "hashes" + return location + + # the default needs to be an instance method + @location.default + def _location_default(self): + return self.location_default() + + @location.validator + def location_validator(self, _, location): + if not os.path.isdir(location): + raise ValueError( + f"Persistent cache location '{location}' is not a directory" + ) + + @cleanup_period.default + def cleanup_period_default(self): + return int(os.environ.get(self.CLEANUP_ENV_VAR, 30)) + + def get_or_calculate_hash(self, key: CacheKey, calculate_hash: ty.Callable) -> Hash: + """Check whether key is present in the persistent cache store and return it if so. + Otherwise use `calculate_hash` to generate the hash and save it in the persistent + store. + + Parameters + ---------- + key : CacheKey + locally unique key (e.g. to the host) used to lookup the corresponding hash + in the persistent store + calculate_hash : ty.Callable + function to calculate the hash if it isn't present in the persistent store + + Returns + ------- + Hash + the hash corresponding to the key, which is either retrieved from the persistent + store or calculated using `calculate_hash` if not present + """ + try: + return self._hashes[key] + except KeyError: + pass + key_path = self.location / blake2b(str(key).encode()).hexdigest() + with SoftFileLock(key_path.with_suffix(".lock")): + if key_path.exists(): + return Hash(key_path.read_bytes()) + hsh = calculate_hash() + key_path.write_bytes(hsh) + self._hashes[key] = Hash(hsh) + return Hash(hsh) + + def clean_up(self): + """Cleans up old hash caches that haven't been accessed in the last 30 days""" + now = datetime.now() + for path in self.location.iterdir(): + if path.name.endswith(".lock"): + continue + days = (now - datetime.fromtimestamp(path.lstat().st_atime)).days + if days > self.cleanup_period: + path.unlink() + + @classmethod + def from_path( + cls, path: ty.Union[Path, str, "PersistentCache", None] + ) -> "PersistentCache": + if isinstance(path, PersistentCache): + return path + return PersistentCache(path) + + +@attrs.define +class Cache: + """Cache for hashing objects, used to avoid infinite recursion caused by circular + references between objects, and to store hashes of objects that have already been + hashed to avoid recomputation. + + This concept is extended to persistent caching of hashes for certain object types, + for which calculating the hash is a potentially expensive operation (e.g. + File/Directory types). For these classes the `bytes_repr` override function yields a + "locally unique cache key" (e.g. file-system path + mtime) as the first item of its + iterator. + """ + + persistent: ty.Optional[PersistentCache] = attrs.field( + default=None, + converter=PersistentCache.from_path, # type: ignore[misc] + ) + _hashes: ty.Dict[int, Hash] = attrs.field(factory=dict) + + def __getitem__(self, object_id: int) -> Hash: + return self._hashes[object_id] + + def __setitem__(self, object_id: int, hsh: Hash): + self._hashes[object_id] = hsh + + def __contains__(self, object_id): + return object_id in self._hashes class UnhashableError(ValueError): """Error for objects that cannot be hashed""" -def hash_function(obj, cache=None): +def hash_function(obj, **kwargs): """Generate hash of object.""" - return hash_object(obj, cache=cache).hex() + return hash_object(obj, **kwargs).hex() -def hash_object(obj: object, cache: ty.Optional[Cache] = None) -> Hash: +def hash_object( + obj: object, + cache: ty.Optional[Cache] = None, + persistent_cache: ty.Union[PersistentCache, Path, None] = None, +) -> Hash: """Hash an object Constructs a byte string that uniquely identifies the object, @@ -74,11 +220,11 @@ def hash_object(obj: object, cache: ty.Optional[Cache] = None) -> Hash: dicts. Custom types can be registered with :func:`register_serializer`. """ if cache is None: - cache = Cache({}) + cache = Cache(persistent=persistent_cache) try: return hash_single(obj, cache) except Exception as e: - raise UnhashableError(f"Cannot hash object {obj!r}") from e + raise UnhashableError(f"Cannot hash object {obj!r} due to '{e}'") from e def hash_single(obj: object, cache: Cache) -> Hash: @@ -91,11 +237,53 @@ def hash_single(obj: object, cache: Cache) -> Hash: if objid not in cache: # Handle recursion by putting a dummy value in the cache cache[objid] = Hash(b"\x00") - h = blake2b(digest_size=16, person=b"pydra-hash") - for chunk in bytes_repr(obj, cache): - h.update(chunk) - hsh = cache[objid] = Hash(h.digest()) + bytes_it = bytes_repr(obj, cache) + # Pop first element from the bytes_repr iterator and check whether it is a + # "local cache key" (e.g. file-system path + mtime tuple) or the first bytes + # chunk + + def calc_hash(first: ty.Optional[bytes] = None) -> Hash: + """ + Calculate the hash of the object + + Parameters + ---------- + first : ty.Optional[bytes] + the first bytes chunk from the bytes_repr iterator, passed if the first + chunk wasn't a local cache key + """ + h = blake2b(digest_size=16, person=b"pydra-hash") + # We want to use the first chunk that was popped to check for a cache-key + # if present + if first is not None: + h.update(first) + for chunk in bytes_it: # Note that `bytes_it` is in outer scope + h.update(chunk) + return Hash(h.digest()) + + # Read the first item of the bytes_repr iterator and check to see whether it yields + # a "cache-key" tuple instead of a bytes chunk for the type of the object to be cached + # (e.g. file-system path + mtime for fileformats.core.FileSet objects). If it + # does, use that key to check the persistent cache for a precomputed hash and + # return it if it is, otherwise calculate the hash and store it in the persistent + # cache with that hash of that key (not to be confused with the hash of the + # object that is saved/retrieved). + first = next(bytes_it) + if isinstance(first, tuple): + tp = type(obj) + key = ( + tp.__module__, + tp.__name__, + ) + first + hsh = cache.persistent.get_or_calculate_hash(key, calc_hash) + else: + # If the first item is a bytes chunk (i.e. the object type doesn't have an + # associated 'cache-key'), then simply calculate the hash of the object, + # passing the first chunk to the `calc_hash` function so it can be included + # in the hash calculation + hsh = calc_hash(first=first) logger.debug("Hash of %s object is %s", obj, hsh) + cache[objid] = hsh return cache[objid] @@ -276,6 +464,18 @@ def type_name(tp): yield b")" +@register_serializer(FileSet) +def bytes_repr_fileset( + fileset: FileSet, cache: Cache +) -> Iterator[ty.Union[CacheKey, bytes]]: + fspaths = sorted(fileset.fspaths) + yield CacheKey( + tuple(repr(p) for p in fspaths) # type: ignore[arg-type] + + tuple(p.lstat().st_mtime_ns for p in fspaths) + ) + yield from fileset.__bytes_repr__(cache) + + @register_serializer(list) @register_serializer(tuple) def bytes_repr_seq(obj: Sequence, cache: Cache) -> Iterator[bytes]: @@ -300,7 +500,7 @@ def bytes_repr_mapping_contents(mapping: Mapping, cache: Cache) -> Iterator[byte .. code-block:: python >>> from pydra.utils.hash import bytes_repr_mapping_contents, Cache - >>> generator = bytes_repr_mapping_contents({"a": 1, "b": 2}, Cache({})) + >>> generator = bytes_repr_mapping_contents({"a": 1, "b": 2}, Cache()) >>> b''.join(generator) b'str:1:a=...str:1:b=...' """ @@ -318,7 +518,7 @@ def bytes_repr_sequence_contents(seq: Sequence, cache: Cache) -> Iterator[bytes] .. code-block:: python >>> from pydra.utils.hash import bytes_repr_sequence_contents, Cache - >>> generator = bytes_repr_sequence_contents([1, 2], Cache({})) + >>> generator = bytes_repr_sequence_contents([1, 2], Cache()) >>> list(generator) [b'\x6d...', b'\xa3...'] """ @@ -339,39 +539,3 @@ def bytes_repr_numpy(obj: numpy.ndarray, cache: Cache) -> Iterator[bytes]: NUMPY_CHUNK_LEN = 8192 - - -# class MtimeCachingHash: -# """Hashing object that stores a cache of hash values for PathLikes - -# The cache only stores values for PathLikes pointing to existing files, -# and the mtime is checked to validate the cache. If the mtime differs, -# the old hash is discarded and a new mtime-tagged hash is stored. - -# The cache can grow without bound; we may want to consider using an LRU -# cache. -# """ - -# def __init__(self) -> None: -# self.cache: ty.Dict[os.PathLike, ty.Tuple[float, Hash]] = {} - -# def __call__(self, obj: object) -> Hash: -# if isinstance(obj, os.PathLike): -# path = Path(obj) -# try: -# stat_res = path.stat() -# mode, mtime = stat_res.st_mode, stat_res.st_mtime -# except FileNotFoundError: -# # Only attempt to cache existing files -# pass -# else: -# if stat.S_ISREG(mode) and obj in self.cache: -# # Cache (and hash) the actual object, as different pathlikes will have -# # different serializations -# save_mtime, save_hash = self.cache[obj] -# if mtime == save_mtime: -# return save_hash -# new_hash = hash_object(obj) -# self.cache[obj] = (mtime, new_hash) -# return new_hash -# return hash_object(obj) diff --git a/pydra/utils/tests/test_hash.py b/pydra/utils/tests/test_hash.py index 8da055e11..56a7d9e68 100644 --- a/pydra/utils/tests/test_hash.py +++ b/pydra/utils/tests/test_hash.py @@ -1,12 +1,22 @@ import re +import os from hashlib import blake2b from pathlib import Path - +import time +from unittest import mock import attrs import pytest import typing as ty from fileformats.application import Zip, Json -from ..hash import Cache, UnhashableError, bytes_repr, hash_object, register_serializer +from fileformats.text import TextFile +from ..hash import ( + Cache, + UnhashableError, + bytes_repr, + hash_object, + register_serializer, + PersistentCache, +) @pytest.fixture @@ -15,7 +25,7 @@ def hasher(): def join_bytes_repr(obj): - return b"".join(bytes_repr(obj, Cache({}))) + return b"".join(bytes_repr(obj, Cache())) def test_bytes_repr_builtins(): @@ -296,3 +306,83 @@ def _(obj: MyClass, cache: Cache): register_serializer(MyNewClass, _) assert join_bytes_repr(MyNewClass(1)) == b"serializer" + + +@pytest.fixture +def cache_path(tmp_path): + cache_path = tmp_path / "hash-cache" + cache_path.mkdir() + return cache_path + + +@pytest.fixture +def text_file(tmp_path): + text_file_path = tmp_path / "text-file.txt" + text_file_path.write_text("foo") + return TextFile(text_file_path) + + +def test_persistent_hash_cache(cache_path, text_file): + """ + Test the persistent hash cache with a text file + + The cache is used to store the hash of the text file, and the hash is + retrieved from the cache when the file is unchanged. + """ + # Test hash is stable between calls + hsh = hash_object(text_file, persistent_cache=cache_path) + assert hsh == hash_object(text_file, persistent_cache=cache_path) + + # Test that cached hash has been used by explicitly modifying it and seeing that the + # hash is the same as the modified hash + cache_files = list(cache_path.iterdir()) + assert len(cache_files) == 1 + modified_hash = "modified".encode() + cache_files[0].write_bytes(modified_hash) + assert hash_object(text_file, persistent_cache=cache_path) == modified_hash + + # Test that changes to the text file result in new hash + time.sleep(2) # Need to ensure that the mtimes will be different + text_file.fspath.write_text("bar") + assert hash_object(text_file, persistent_cache=cache_path) != modified_hash + assert len(list(cache_path.iterdir())) == 2 + + +def test_persistent_hash_cache_cleanup1(cache_path, text_file): + """ + Test the persistent hash is cleaned up after use if the periods between cleanups + is greater than the environment variable PYDRA_HASH_CACHE_CLEANUP_PERIOD + """ + with mock.patch.dict( + os.environ, + { + "PYDRA_HASH_CACHE": str(cache_path), + "PYDRA_HASH_CACHE_CLEANUP_PERIOD": "-100", + }, + ): + persistent_cache = PersistentCache() + hash_object(text_file, persistent_cache=persistent_cache) + assert len(list(cache_path.iterdir())) == 1 + persistent_cache.clean_up() + assert len(list(cache_path.iterdir())) == 0 + + +def test_persistent_hash_cache_cleanup2(cache_path, text_file): + """ + Test the persistent hash is cleaned up after use if the periods between cleanups + is greater than the explicitly provided cleanup_period + """ + persistent_cache = PersistentCache(cache_path, cleanup_period=-100) + hash_object(text_file, persistent_cache=persistent_cache) + assert len(list(cache_path.iterdir())) == 1 + time.sleep(2) + persistent_cache.clean_up() + assert len(list(cache_path.iterdir())) == 0 + + +def test_persistent_hash_cache_not_dir(text_file): + """ + Test that an error is raised if the provided cache path is not a directory + """ + with pytest.raises(ValueError, match="is not a directory"): + PersistentCache(text_file.fspath) diff --git a/pyproject.toml b/pyproject.toml index ad8f61ea8..6a6ad5e70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "filelock >=3.0.0", "fileformats >=0.8", "importlib_resources >=5.7; python_version < '3.11'", + "platformdirs >=2", "typing_extensions >=4.6.3; python_version < '3.10'", "typing_utils >=0.1.0; python_version < '3.10'", ]