feat: First-class caching #1104

Merged: 26 commits, Oct 3, 2024
Changes from 1 commit

Commits (26)
dbe5a3c
added basic fingerprinting support
Jul 18, 2024
998ca73
fixed tests and documentation build
Sep 25, 2024
d6cbcb6
fixed pre-commits
Sep 25, 2024
dc4a75c
fixed NoneType for <3.10
Sep 25, 2024
b968fba
fix DISABLE behavior and handle unhashable values
Sep 26, 2024
95d5e4f
fixed DISABLE behavior; fixed cache key name collision with extract_f…
Sep 26, 2024
1c2873f
fixed cache_key signature in test
Sep 26, 2024
9be419b
added caching tutorial notebook
Sep 27, 2024
ea74a27
fixed materializer bug
Sep 27, 2024
b68775d
rename kwarg to dependencies_data_versions
Sep 30, 2024
73cdf61
DEFAULT behavior can be specified at the Driver
Sep 30, 2024
ccc6c01
changed default cache path
Sep 30, 2024
82139db
renamed drivers and modules with descriptive names
Sep 30, 2024
963e5fc
updated tutorial notebook with feedback
Sep 30, 2024
6b0df62
renamed to HamiltonCacheAdapter, remove materializer RECOMPUTE behavior
Oct 1, 2024
571e4f7
fix default parameter and materializer; handle stdlib types
Oct 2, 2024
c4ffb9f
updated tutorial notebook; added materializer tutorial
Oct 2, 2024
eaca75c
fixed tests for dataloader tag change
Oct 2, 2024
39df87b
updated docs; added notebook rendering support
Oct 2, 2024
1a64444
fixed singledispatch type annotation for older Python versions
Oct 2, 2024
1356773
cleaned up docs; fixed type 3.8; remove docs dep
Oct 3, 2024
89388ea
fix docs config
Oct 3, 2024
eb6e277
remove unreliable test for 3.8
Oct 3, 2024
e27ee53
fixed recursion error for unordered maps
Oct 3, 2024
e778d51
add mention about cache keys stability
Oct 3, 2024
75cc9f3
fixed max_depth in recursive hashing test
Oct 3, 2024
143 changes: 134 additions & 9 deletions hamilton/caching/adapter.py
@@ -36,6 +36,9 @@
S = TypeVar("S", object, object)
Collaborator

What's the value of this generic type variable? We just have it in outputs, so I'm not sure it'll help with code completion.

Collaborator Author
@zilto, Sep 30, 2024

Having -> Union[str, S] improves readability by signaling that the function can return the SENTINEL value; it's much less ambiguous than dropping the TypeVar and writing -> Union[str, object].

It was your recommendation to use a SENTINEL value and it did help clarify the code and catch errors.

This doesn't have anything to do with code completion and I don't think it needs to?

Collaborator

IMO a TypeVar really shouldn't be used if it's not linked across two places -- e.g. if this class were implemented as a generic over the TypeVar, then we'd get value from it. But I guess this is fine -- Pylance isn't complaining (although it does complain when a TypeVar is only used in one place).

Specifically, TypeVars are best (IMO) when cascaded through different parts of the code, e.g. if you pass in a str, you get out a str -- hence what I meant by code completion. In this case it's really just a union type, not a type variable.

Better (IMO) would be S = Union[str, object]. But I think it's clearer to have Union[str, object] everywhere.

Pylance doesn't like it because it can't actually parse that. So this is fine for now.

Funnily enough, Union[str, object] is kind of silly, as a str is an object.

Anyway, not a huge deal, but these are my thoughts.
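A minimal, standalone sketch of the distinction in this thread (function names invented for illustration): a TypeVar pays off when it links an input to an output, while one that only appears in a return annotation behaves like a plain union.

from typing import List, TypeVar, Union

T = TypeVar("T")
SENTINEL = object()  # stand-in for the adapter's sentinel value

def first_or_sentinel(items: List[T]) -> Union[T, object]:
    # TypeVar linked across input and output: callers get back the element type.
    return items[0] if items else SENTINEL

def stored_version(key: str) -> Union[str, object]:
    # A TypeVar used only in the return position would add nothing over this plain union.
    return "some-hash" if key else SENTINEL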



CACHING_BEHAVIORS = Literal["default", "recompute", "disable", "ignore"]


class CachingBehavior(enum.Enum):
"""Behavior applied by the caching adapter

@@ -125,8 +128,7 @@ class CachingEventType(enum.Enum):
IS_OVERRIDE = "is_override"
IS_INPUT = "is_input"
IS_FINAL_VAR = "is_final_var"
APPLY_DATA_SAVER = "apply_data_saver"
APPLY_DATA_LOADER = "apply_data_loader"
IS_DEFAULT_PARAMETER_VALUE = "is_default_parameter_value"


@dataclasses.dataclass(frozen=True)
@@ -200,8 +202,11 @@ def __init__(
result_store: Optional[ResultStore] = None,
default: Optional[Union[Literal[True], Collection[str]]] = None,
recompute: Optional[Union[Literal[True], Collection[str]]] = None,
disable: Optional[Union[Literal[True], Collection[str]]] = None,
ignore: Optional[Union[Literal[True], Collection[str]]] = None,
disable: Optional[Union[Literal[True], Collection[str]]] = None,
default_behavior: Optional[CACHING_BEHAVIORS] = None,
default_loader_behavior: Optional[CACHING_BEHAVIORS] = None,
default_saver_behavior: Optional[CACHING_BEHAVIORS] = None,
log_to_file: bool = False,
**kwargs,
):
@@ -212,8 +217,11 @@ def __init__(
:param result_store: BaseStore caching dataflow execution results
:param default: Set caching behavior to DEFAULT for specified node names. If True, apply to all nodes.
:param recompute: Set caching behavior to RECOMPUTE for specified node names. If True, apply to all nodes.
:param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes.
:param ignore: Set caching behavior to IGNORE for specified node names. If True, apply to all nodes.
:param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes.
:param default_behavior: Set the default caching behavior.
:param default_loader_behavior: Set the default caching behavior for `DataLoader` nodes.
:param default_saver_behavior: Set the default caching behavior for `DataSaver` nodes.
:param log_to_file: If True, append cache event logs as they happen in JSONL format.
"""
self._path = path
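For orientation, a hedged usage sketch of how these arguments might be combined when wiring the adapter into a Driver; the module my_dataflow and the node names are invented, and the Builder wiring is an assumption rather than something shown in this diff.

from hamilton import driver
from hamilton.caching.adapter import HamiltonCacheAdapter

import my_dataflow  # hypothetical user module

cache = HamiltonCacheAdapter(
    recompute=["raw_data", "features"],   # hypothetical node names to always recompute
    ignore=["debug_summary"],             # hypothetical node excluded from cache keys
    default_loader_behavior="recompute",  # DataLoader nodes re-read their source by default
    log_to_file=True,                     # append cache events as JSONL
)

dr = driver.Builder().with_modules(my_dataflow).with_adapters(cache).build()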
@@ -233,10 +241,15 @@ def __init__(
self._recompute = recompute
self._disable = disable
self._ignore = ignore
self.default_behavior = default_behavior
self.default_loader_behavior = default_loader_behavior
self.default_saver_behavior = default_saver_behavior

# attributes populated at execution time
self.run_ids: List[str] = []
self._fn_graphs: Dict[str, FunctionGraph] = {} # {run_id: graph}
self._data_savers: Dict[str, Collection[str]] = {} # {run_id: list[node_name]}
self._data_loaders: Dict[str, Collection[str]] = {} # {run_id: list[node_name]}
self.behaviors: Dict[
str, Dict[str, CachingBehavior]
] = {} # {run_id: {node_name: behavior}}
@@ -828,6 +841,9 @@ def _resolve_node_behavior(
disable: Optional[Collection[str]] = None,
recompute: Optional[Collection[str]] = None,
ignore: Optional[Collection[str]] = None,
default_behavior: CACHING_BEHAVIORS = "default",
default_loader_behavior: CACHING_BEHAVIORS = "default",
default_saver_behavior: CACHING_BEHAVIORS = "default",
) -> CachingBehavior:
"""Determine the cache behavior of a node.
Behavior specified via the ``Builder`` has precedence over the ``@cache`` decorator.
@@ -856,16 +872,20 @@ def _resolve_node_behavior(
if node.name in node_set:
if behavior_from_driver is not SENTINEL:
raise ValueError(
f"Multiple caching behaviors specifiself.resolve_behaviors(run_id=run_id, graph=graph)ed by Driver for node: {node.name}"
f"Multiple caching behaviors specified by Driver for node: {node.name}"
)
behavior_from_driver = behavior

if behavior_from_driver is not SENTINEL:
return behavior_from_driver
elif behavior_from_tag is not SENTINEL:
return behavior_from_tag
elif node.tags.get("hamilton.data_loader"):
return CachingBehavior.from_string(default_loader_behavior)
elif node.tags.get("hamilton.data_saver"):
return CachingBehavior.from_string(default_saver_behavior)
else:
return CachingBehavior.DEFAULT
return CachingBehavior.from_string(default_behavior)
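The same precedence chain, restated as a self-contained sketch stripped of the Node/tag machinery (names here are illustrative only): driver-specified behavior wins, then the @cache decorator's tag, then the loader/saver defaults, then the global default.

from typing import Dict, Optional, Set

def resolve_behavior_sketch(
    node_name: str,
    from_driver: Dict[str, Set[str]],  # e.g. {"recompute": {"raw_data"}, "disable": set()}
    from_tag: Optional[str],
    is_loader: bool,
    is_saver: bool,
    default: str = "default",
    loader_default: str = "default",
    saver_default: str = "default",
) -> str:
    # Driver-specified behavior has top priority; naming a node twice is an error.
    matches = [b for b, names in from_driver.items() if node_name in names]
    if len(matches) > 1:
        raise ValueError(f"Multiple caching behaviors specified by Driver for node: {node_name}")
    if matches:
        return matches[0]
    # Next comes the behavior set via the @cache decorator (surfaced as a tag).
    if from_tag is not None:
        return from_tag
    # Loaders and savers fall back to their own defaults, everything else to the global default.
    if is_loader:
        return loader_default
    if is_saver:
        return saver_default
    return default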

def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
"""Resolve the caching behavior for each node based on the ``@cache`` decorator
@@ -898,6 +918,18 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
elif _ignore is True:
_ignore = [n.name for n in graph.get_nodes()]

default_behavior = "default"
if self.default_behavior is not None:
default_behavior = self.default_behavior

default_loader_behavior = default_behavior
if self.default_loader_behavior is not None:
default_loader_behavior = self.default_loader_behavior

default_saver_behavior = default_behavior
if self.default_saver_behavior is not None:
default_saver_behavior = self.default_saver_behavior

behaviors = {}
for node in graph.get_nodes():
behavior = HamiltonCacheAdapter._resolve_node_behavior(
@@ -906,6 +938,9 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
disable=_disable,
recompute=_recompute,
ignore=_ignore,
default_behavior=default_behavior,
default_loader_behavior=default_loader_behavior,
default_saver_behavior=default_saver_behavior,
)
behaviors[node.name] = behavior

@@ -917,6 +952,41 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
event_type=CachingEventType.RESOLVE_BEHAVIOR,
value=behavior,
)

# need to handle materializers via a second pass to copy the behavior
# of their "main node"
for node in graph.get_nodes():
if node.tags.get("hamilton.data_loader") is True:
main_node = node.tags["hamilton.data_loader.node"]
if main_node == node.name:
continue

# solution for `@dataloader` and `from_`
if behaviors.get(main_node, None) is not None:
behaviors[node.name] = behaviors[main_node]
# this hacky section is required to support @load_from and provide
Collaborator

Hmm, not sure I follow what's happening here. Let's add some more docs.

Collaborator

Specifically, document desired behavior here/how the code maps to it

# a unified pattern to specify behavior from the module or the driver
else:
behaviors[node.name] = HamiltonCacheAdapter._resolve_node_behavior(
# we create a fake node, only its name matters
node=hamilton.node.Node(
name=main_node,
typ=str,
callabl=lambda: None,
tags=node.tags.copy(),
),
default=_default,
disable=_disable,
recompute=_recompute,
ignore=_ignore,
default_behavior=default_loader_behavior,
)

self._data_loaders[run_id].append(main_node)

if node.tags.get("hamilton.data_saver", None) is not None:
self._data_savers[run_id].append(node.name)

return behaviors
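To make the second pass above a bit more concrete, a hedged illustration of the tag layout it relies on; the node names and the copied behavior are invented, only the tag keys come from the checks in the loop.

# A loader node produced by @load_from / from_ carries tags like these:
loader_node_tags = {
    "hamilton.data_loader": True,
    "hamilton.data_loader.node": "raw_df",  # the "main node" whose behavior it should mirror
}

# First pass: ordinary nodes get their behavior resolved.
behaviors = {"raw_df": "recompute"}  # hypothetical result of the first pass

# Second pass: the loader node (hypothetical name "raw_df.load_data") copies its
# main node's behavior, so recomputing "raw_df" also re-runs the load itself.
behaviors["raw_df.load_data"] = behaviors[loader_node_tags["hamilton.data_loader.node"]]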

def resolve_code_versions(
Expand Down Expand Up @@ -992,6 +1062,24 @@ def _process_override(self, run_id: str, node_name: str, value: Any) -> None:
value=data_version,
)

@staticmethod
def _resolve_default_parameter_values(
node_: hamilton.node.Node, node_kwargs: dict[str, Any]
) -> dict[str, Any]:
"""
If a node uses the function's default parameter values, they won't be part of the
node_kwargs. To ensure a consistent `cache_key` we want to retrieve default parameter
values if they're used
"""
resolved_kwargs = node_kwargs.copy()
for param_name, param_value in node_.default_parameter_values.items():
# if the `param_name` not in `node_kwargs`, it means the node uses the default
# parameter value
if param_name not in node_kwargs.keys():
resolved_kwargs.update(**{param_name: param_value})

return resolved_kwargs
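A quick, self-contained illustration of why this matters for cache-key stability; the node function and kwargs below are invented.

import inspect

def trimmed(text: str, max_length: int = 100) -> str:
    return text[:max_length]

# Hamilton does not pass unused default parameters in node_kwargs, so the same
# computation could hash {"text": ...} on one run and {"text": ..., "max_length": 100}
# on another, yielding two different cache keys. Folding defaults in keeps the key stable.
node_kwargs = {"text": "hello world"}

defaults = {
    name: param.default
    for name, param in inspect.signature(trimmed).parameters.items()
    if param.default is not inspect.Parameter.empty
}
resolved_kwargs = {**defaults, **node_kwargs}  # explicit kwargs still win over defaults
assert resolved_kwargs == {"max_length": 100, "text": "hello world"}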

def pre_graph_execute(
self,
*,
@@ -1016,6 +1104,10 @@ def pre_graph_execute(
self.code_versions[run_id] = self.resolve_code_versions(
run_id=run_id, final_vars=final_vars, inputs=inputs, overrides=overrides
)
# the empty `._data_loaders` and `._data_savers` need to be instantiated before calling
# `self.resolve_behaviors` because it appends to them
self._data_loaders[run_id] = []
self._data_savers[run_id] = []
self.behaviors[run_id] = self.resolve_behaviors(run_id=run_id)

# final vars are logged to be retrieved by the ``.view_run()`` method
@@ -1060,7 +1152,7 @@ def pre_node_execute(

"""
node_name = node_.name
node_kwargs = kwargs
node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs)

if self.behaviors[run_id][node_name] == CachingBehavior.IGNORE:
return
@@ -1154,7 +1246,7 @@ def do_node_execute(
"""
node_name = node_.name
node_callable = node_.callable
node_kwargs = kwargs
node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs)

if self.behaviors[run_id][node_name] in (
CachingBehavior.DISABLE,
@@ -1173,7 +1265,23 @@ def do_node_execute(
CachingBehavior.IGNORE,
):
cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id)

# nodes collected in `._data_loaders` return tuples of (result, metadata)
# where metadata often includes a timestamp. To ensure we provide a consistent
# `data_version` / hash, we only hash the result part of the materializer return
# value and discard the metadata.
if node_name in self._data_loaders[run_id] and isinstance(result, tuple):
result = result[0]

data_version = self._version_data(node_name=node_name, run_id=run_id, result=result)

# nodes collected in `._data_savers` return a dictionary of metadata
Collaborator

I'm not convinced we need to special case these... But for now, let's call this a bit experimental (data savers/loaders) -- it's a very odd case that we don't want to dwell on more.

# this metadata often includes a timestamp, leading to an unstable hash.
# we do not version nor store the metadata. This node is executed for its
# external effect of saving a file
if node_name in self._data_savers[run_id]:
data_version = f"{node_name}__metadata"

self._set_memory_metadata(
run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version
)
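A hedged illustration (invented values) of the two special cases handled above.

# A loader node returns (result, metadata); only the result is hashed because the
# metadata typically carries a timestamp that would make the data version unstable.
loader_result = ({"a": [1, 2, 3]}, {"path": "data.csv", "timestamp": 1727900000.0})
value_to_version = loader_result[0]

# A saver node returns only metadata; rather than hashing it, a fixed label is stored.
saver_node_name = "save_output"  # hypothetical saver node name
data_version = f"{saver_node_name}__metadata"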
Expand Down Expand Up @@ -1244,7 +1352,23 @@ def do_node_execute(
node_kwargs=node_kwargs,
task_id=task_id,
)

# nodes collected in `._data_loaders` return tuples of (result, metadata)
# where metadata often includes a timestamp. To ensure we provide a consistent
# `data_version` / hash, we only hash the result part of the materializer return
# value and discard the metadata.
if node_name in self._data_loaders[run_id] and isinstance(result, tuple):
result = result[0]

data_version = self._version_data(node_name=node_name, run_id=run_id, result=result)

# nodes collected in `._data_savers` return a dictionary of metadata
# this metadata often includes a timestamp, leading to an unstable hash.
# we do not version nor store the metadata. This node is executed for its
# external effect of saving a file
if node_name in self._data_savers[run_id]:
data_version = f"{node_name}__metadata"

self._set_memory_metadata(
run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version
)
@@ -1298,7 +1422,8 @@ def post_node_execute(
assert data_version is not SENTINEL

# TODO clean up this logic
# check for materialized file when using `@cache(format="json")`
# check if a materialized file exists before writing results
# when using `@cache(format="json")`
cache_format = (
self._fn_graphs[run_id]
.nodes[node_name]