Skip to content

Commit

Permalink
New mechanism to enable/disable resolver cache
Browse files Browse the repository at this point in the history
In particular the `env` resolver does not use the cache anymore.
  • Loading branch information
odelalleau committed Aug 12, 2020
1 parent 70d6f4b commit f7afe4b
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 14 deletions.
59 changes: 45 additions & 14 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def env(key: str, default: Any = _EMPTY_MARKER_, *, config: BaseContainer) -> An
val = visitor.visit(parse_tree)
return _get_value(val)

# Note that the `env` resolver does *NOT* use the cache.
OmegaConf.register_resolver(
"env", env, config_arg="config", variables_as_strings=False
"env", env, config_arg="config", variables_as_strings=False, use_cache=False,
)


Expand Down Expand Up @@ -333,6 +334,7 @@ def register_resolver(
variables_as_strings: bool = True,
config_arg: Optional[str] = None,
parent_arg: Optional[str] = None,
use_cache: Optional[bool] = None,
) -> None:
"""
The `variables_as_strings` flag was introduced to preserve backward compatibility
Expand All @@ -352,13 +354,37 @@ def register_resolver(
of `resolver` (of type `Optional[Container]`) to the parent of the key being
processed when the resolver is called. This can be useful for operations involving
other config options relative to the current key.
`use_cache` indicates whether the resolver's outputs should be cached. When not
provided, it defaults to `True` unless either `config_arg` or `parent_arg` is
used. In such situations it defaults to `False` and the user is warned to
explicitly set `use_cache=False` to make it clear that no caching is done
(currently caching is not supported when using `config_arg` or `parent_arg`).
"""
assert callable(resolver), "resolver must be callable"
# noinspection PyProtectedMember
assert (
name not in BaseContainer._resolvers
), "resolver {} is already registered".format(name)

if use_cache is None:
if config_arg is not None or parent_arg is not None:
warnings.warn(
f"You are using either `config_arg` or `parent_arg` to register "
f"resolver `{name}`: caching is not supported in such a case, and "
f"you must explicitly set `use_cache=False` to disable this warning.",
stacklevel=2,
)
use_cache = False
else:
use_cache = True
elif use_cache and (config_arg is not None or parent_arg is not None):
raise NotImplementedError(
f"Caching is not supported when using either `config_arg` or "
f"`parent_arg`, please set `use_cache=False` when registering "
f"resolver `{name}`",
)

def resolver_wrapper(
config: BaseContainer,
parent: Optional[Container],
Expand All @@ -381,19 +407,24 @@ def resolver_wrapper(
else:
inputs = key

cache = OmegaConf.get_cache(config)[name]
hashable_key = _make_hashable(key)
try:
val = cache[hashable_key]
except KeyError:
# Call resolver.
optional_args: Dict[str, Optional[Container]] = {}
if config_arg is not None:
optional_args[config_arg] = config
if parent_arg is not None:
optional_args[parent_arg] = parent
val = cache[hashable_key] = resolver(*inputs, **optional_args)
return val
if use_cache:
cache = OmegaConf.get_cache(config)[name]
hashable_key = _make_hashable(key)
try:
return cache[hashable_key]
except KeyError:
pass

# Call resolver.
optional_args: Dict[str, Optional[Container]] = {}
if config_arg is not None:
optional_args[config_arg] = config
if parent_arg is not None:
optional_args[parent_arg] = parent
ret = resolver(*inputs, **optional_args)
if use_cache:
cache[hashable_key] = ret
return ret

# noinspection PyProtectedMember
BaseContainer._resolvers[name] = resolver_wrapper
Expand Down
71 changes: 71 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ def test_env_default_interpolation_env_exist() -> None:
assert c.path == "/test/1234"


def test_env_is_not_cached() -> None:
os.environ["foobar"] = "1234"
c = OmegaConf.create({"foobar": "${env:foobar}"})
before = c.foobar
os.environ["foobar"] = "3456"
assert c.foobar != before


@pytest.mark.parametrize( # type: ignore
"value,expected",
[
Expand Down Expand Up @@ -238,6 +246,7 @@ def test_register_resolver_access_config(restore_resolvers: Any) -> None:
"len",
lambda value, *, root: len(OmegaConf.select(root, value)),
config_arg="root",
use_cache=False,
)
c = OmegaConf.create({"list": [1, 2, 3], "list_len": "${len:list}"})
assert c.list_len == 3
Expand All @@ -248,6 +257,7 @@ def test_register_resolver_access_parent(restore_resolvers: Any) -> None:
"get_sibling",
lambda sibling, *, parent: getattr(parent, sibling),
parent_arg="parent",
use_cache=False,
)
c = OmegaConf.create(
"""
Expand All @@ -261,6 +271,59 @@ def test_register_resolver_access_parent(restore_resolvers: Any) -> None:
assert c.root.foo.bar.baz1 == "useful data"


def test_register_resolver_access_parent_no_cache(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"add_noise_to_sibling",
lambda sibling, *, parent: random.uniform(0, 1) + getattr(parent, sibling),
parent_arg="parent",
use_cache=False,
)
c = OmegaConf.create(
"""
root:
foo:
baz1: "${add_noise_to_sibling:baz2}"
baz2: 1
bar:
baz1: "${add_noise_to_sibling:baz2}"
baz2: 1
"""
)
assert c.root.foo.baz2 == c.root.bar.baz2 # make sure we test what we want to test
assert c.root.foo.baz1 != c.root.foo.baz1 # same node (regular "no cache" behavior)
assert c.root.foo.baz1 != c.root.bar.baz1 # same args but different parents


def test_register_resolver_cache_warnings(restore_resolvers: Any) -> None:
with pytest.warns(UserWarning):
OmegaConf.register_resolver(
"test_warning_parent", lambda *, parent: None, parent_arg="parent"
)

with pytest.warns(UserWarning):
OmegaConf.register_resolver(
"test_warning_config", lambda *, config: None, config_arg="config"
)


def test_register_resolver_cache_errors(restore_resolvers: Any) -> None:
with pytest.raises(NotImplementedError):
OmegaConf.register_resolver(
"test_error_parent",
lambda *, parent: None,
parent_arg="parent",
use_cache=True,
)

with pytest.raises(NotImplementedError):
OmegaConf.register_resolver(
"test_error_config",
lambda *, config: None,
config_arg="config",
use_cache=True,
)


def test_resolver_cache_1(restore_resolvers: Any) -> None:
# resolvers are always converted to stateless idempotent functions
# subsequent calls to the same function with the same argument will always return the same value.
Expand Down Expand Up @@ -311,6 +374,14 @@ def test_resolver_cache_3_dict_list(restore_resolvers: Any) -> None:
assert c.mixed1 != c.mixed2


def test_resolver_no_cache(restore_resolvers: Any) -> None:
OmegaConf.register_resolver(
"random", lambda _: random.uniform(0, 1), use_cache=False
)
c = OmegaConf.create(dict(k="${random:_}"))
assert c.k != c.k


@pytest.mark.parametrize( # type: ignore
"resolver,name,key,result",
[
Expand Down

0 comments on commit f7afe4b

Please sign in to comment.