Skip to content

Commit 78c07bf

Browse files
authored
FIX: make values in mapping hashable (#470)
* DX: write test for `make_hahsable()` * ENH: avoid returning tuple when one arg * ENH: embed function source code into hash * ENH: embed Python version in cache hash * FIX: call `_make_hashable_impl` in `_make_hashable_impl` * FIX: show warning before evaluating function
1 parent 731141f commit 78c07bf

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

src/ampform/sympy/_cache.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from __future__ import annotations
1111

1212
import hashlib
13+
import inspect
1314
import logging
1415
import os
1516
import pickle # noqa: S403
@@ -94,7 +95,9 @@ def decorator(func: Callable[P, T]) -> Callable[P, T]:
9495
if "NO_CACHE" in os.environ:
9596
_warn_once("AmpForm cache disabled by NO_CACHE environment variable.")
9697
return func
98+
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
9799
function_identifier = f"{func.__module__}.{func.__name__}"
100+
src = inspect.getsource(func)
98101
dependency_identifiers = _get_dependency_identifiers(func, dependencies or [])
99102
nonlocal function_name
100103
if function_name is None:
@@ -103,16 +106,21 @@ def decorator(func: Callable[P, T]) -> Callable[P, T]:
103106
@wraps(func)
104107
def wrapped_function(*args: P.args, **kwargs: P.kwargs) -> T:
105108
hashable_object = make_hashable(
106-
function_identifier, *dependency_identifiers, args, kwargs
109+
function_identifier,
110+
src,
111+
python_version,
112+
*dependency_identifiers,
113+
args,
114+
kwargs,
107115
)
108116
h = get_readable_hash(hashable_object)
109117
cache_file = _get_cache_dir() / h[:2] / h[2:]
110118
if cache_file.exists():
111119
with open(cache_file, "rb") as f:
112120
return load_function(f)
113-
result = func(*args, **kwargs)
114121
msg = f"No cache file {cache_file}, performing {function_name}()..."
115122
_LOGGER.warning(msg)
123+
result = func(*args, **kwargs)
116124
cache_file.parent.mkdir(exist_ok=True, parents=True)
117125
with open(cache_file, "wb") as f:
118126
dump_function(result, f)
@@ -219,16 +227,27 @@ def to_bytes(obj) -> bytes:
219227

220228

221229
def make_hashable(*args) -> Hashable:
230+
"""Make a hashable object from any Python object.
231+
232+
>>> make_hashable("a", 1, {"b": 2}, {3, 4})
233+
('a', 1, frozendict.frozendict({'b': 2}), frozenset({3, 4}))
234+
>>> make_hashable({"a": {"sub-key": {1, 2, 3}, "b": [4, 5]}})
235+
frozendict.frozendict({'a': frozendict.frozendict({'sub-key': frozenset({1, 2, 3}), 'b': (4, 5)})})
236+
>>> make_hashable("already-hashable")
237+
'already-hashable'
238+
"""
239+
if len(args) == 1:
240+
return _make_hashable_impl(args[0])
222241
return tuple(_make_hashable_impl(x) for x in args)
223242

224243

225244
def _make_hashable_impl(obj) -> Hashable:
226245
if isinstance(obj, abc.Mapping):
227-
return frozendict(obj)
246+
return frozendict({k: _make_hashable_impl(v) for k, v in obj.items()})
228247
if isinstance(obj, str):
229248
return obj
230249
if isinstance(obj, abc.Iterable):
231-
hashable_items = (make_hashable(x) for x in obj)
250+
hashable_items = (_make_hashable_impl(x) for x in obj)
232251
if isinstance(obj, abc.Sequence):
233252
return tuple(hashable_items)
234253
if isinstance(obj, set):

0 commit comments

Comments
 (0)