feat: First-class caching #1104
Changes from 1 commit
@@ -36,6 +36,9 @@
S = TypeVar("S", object, object)


CACHING_BEHAVIORS = Literal["default", "recompute", "disable", "ignore"]


class CachingBehavior(enum.Enum):
    """Behavior applied by the caching adapter
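For reviewers skimming the new type alias, here is a hedged sketch of how the Literal and the enum relate. The member values and the `from_string` body are assumptions; only the member names (from the docstrings below) and the `from_string` calls later in the diff are from the source.

```python
import enum
from typing import Literal

CACHING_BEHAVIORS = Literal["default", "recompute", "disable", "ignore"]


class CachingBehavior(enum.Enum):
    # Member values assumed to mirror the literal strings; only DEFAULT appears verbatim in this diff.
    DEFAULT = "default"
    RECOMPUTE = "recompute"
    DISABLE = "disable"
    IGNORE = "ignore"

    @classmethod
    def from_string(cls, value: str) -> "CachingBehavior":
        # from_string(...) is called later in this diff; this body is a plausible sketch.
        return cls(value)
```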
@@ -125,8 +128,7 @@ class CachingEventType(enum.Enum):
    IS_OVERRIDE = "is_override"
    IS_INPUT = "is_input"
    IS_FINAL_VAR = "is_final_var"
    APPLY_DATA_SAVER = "apply_data_saver"
    APPLY_DATA_LOADER = "apply_data_loader"
    IS_DEFAULT_PARAMETER_VALUE = "is_default_parameter_value"


@dataclasses.dataclass(frozen=True)
@@ -200,8 +202,11 @@ def __init__(
    result_store: Optional[ResultStore] = None,
    default: Optional[Union[Literal[True], Collection[str]]] = None,
    recompute: Optional[Union[Literal[True], Collection[str]]] = None,
    disable: Optional[Union[Literal[True], Collection[str]]] = None,
    ignore: Optional[Union[Literal[True], Collection[str]]] = None,
    disable: Optional[Union[Literal[True], Collection[str]]] = None,
    default_behavior: Optional[CACHING_BEHAVIORS] = None,
    default_loader_behavior: Optional[CACHING_BEHAVIORS] = None,
    default_saver_behavior: Optional[CACHING_BEHAVIORS] = None,
    log_to_file: bool = False,
    **kwargs,
):
@@ -212,8 +217,11 @@ def __init__(
    :param result_store: BaseStore caching dataflow execution results
    :param default: Set caching behavior to DEFAULT for specified node names. If True, apply to all nodes.
    :param recompute: Set caching behavior to RECOMPUTE for specified node names. If True, apply to all nodes.
    :param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes.
    :param ignore: Set caching behavior to IGNORE for specified node names. If True, apply to all nodes.
    :param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes.
    :param default_behavior: Set the default caching behavior.
    :param default_loader_behavior: Set the default caching behavior for `DataLoader` nodes.
    :param default_saver_behavior: Set the default caching behavior for `DataSaver` nodes.
    :param log_to_file: If True, append cache event logs as they happen in JSONL format.
    """
    self._path = path
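A hedged usage sketch of the new keyword arguments. The import path and node name are assumptions for illustration; only the class name and its parameters appear in this diff.

```python
# Import path assumed for illustration; only the class name and parameters are shown in this diff.
from hamilton.caching.adapter import HamiltonCacheAdapter

# Recompute everything by default, but leave loaders and savers on the standard behavior.
adapter = HamiltonCacheAdapter(
    default_behavior="recompute",
    default_loader_behavior="default",
    default_saver_behavior="default",
)

# Node-level settings still take precedence over the defaults (node name is hypothetical).
adapter = HamiltonCacheAdapter(
    default_behavior="ignore",
    recompute=["train_model"],
)
```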
@@ -233,10 +241,15 @@ def __init__(
    self._recompute = recompute
    self._disable = disable
    self._ignore = ignore
    self.default_behavior = default_behavior
    self.default_loader_behavior = default_loader_behavior
    self.default_saver_behavior = default_saver_behavior

    # attributes populated at execution time
    self.run_ids: List[str] = []
    self._fn_graphs: Dict[str, FunctionGraph] = {}  # {run_id: graph}
    self._data_savers: Dict[str, Collection[str]] = {}  # {run_id: list[node_name]}
    self._data_loaders: Dict[str, Collection[str]] = {}  # {run_id: list[node_name]}
    self.behaviors: Dict[
        str, Dict[str, CachingBehavior]
    ] = {}  # {run_id: {node_name: behavior}}
@@ -828,6 +841,9 @@ def _resolve_node_behavior(
    disable: Optional[Collection[str]] = None,
    recompute: Optional[Collection[str]] = None,
    ignore: Optional[Collection[str]] = None,
    default_behavior: CACHING_BEHAVIORS = "default",
    default_loader_behavior: CACHING_BEHAVIORS = "default",
    default_saver_behavior: CACHING_BEHAVIORS = "default",
) -> CachingBehavior:
    """Determine the cache behavior of a node.
    Behavior specified via the ``Builder`` has precedence over the ``@cache`` decorator.
@@ -856,16 +872,20 @@ def _resolve_node_behavior(
        if node.name in node_set:
            if behavior_from_driver is not SENTINEL:
                raise ValueError(
                    f"Multiple caching behaviors specifiself.resolve_behaviors(run_id=run_id, graph=graph)ed by Driver for node: {node.name}"
                    f"Multiple caching behaviors specified by Driver for node: {node.name}"
                )
            behavior_from_driver = behavior

    if behavior_from_driver is not SENTINEL:
        return behavior_from_driver
    elif behavior_from_tag is not SENTINEL:
        return behavior_from_tag
    elif node.tags.get("hamilton.data_loader"):
        return CachingBehavior.from_string(default_loader_behavior)
    elif node.tags.get("hamilton.data_saver"):
        return CachingBehavior.from_string(default_saver_behavior)
    else:
        return CachingBehavior.DEFAULT
        return CachingBehavior.from_string(default_behavior)


def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
    """Resolve the caching behavior for each node based on the ``@cache`` decorator
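To restate the precedence implemented above as a reviewer-facing sketch (the helper and its arguments are illustrative, not part of the adapter), highest priority first: driver/adapter setting for the node name, then the `@cache` decorator (read from tags), then the loader/saver defaults for materializer nodes, then the general default.

```python
# Illustrative mirror of the resolution order above; not part of the adapter itself.
def sketch_resolve(node, driver_choice, tag_choice, defaults):
    if driver_choice is not None:        # 1. requested on the driver/adapter
        return driver_choice
    if tag_choice is not None:           # 2. set via the @cache decorator
        return tag_choice
    if node.tags.get("hamilton.data_loader"):
        return defaults["loader"]        # 3a. loader default
    if node.tags.get("hamilton.data_saver"):
        return defaults["saver"]         # 3b. saver default
    return defaults["default"]           # 4. general default
```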
@@ -898,6 +918,18 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
    elif _ignore is True:
        _ignore = [n.name for n in graph.get_nodes()]

    default_behavior = "default"
    if self.default_behavior is not None:
        default_behavior = self.default_behavior

    default_loader_behavior = default_behavior
    if self.default_loader_behavior is not None:
        default_loader_behavior = self.default_loader_behavior

    default_saver_behavior = default_behavior
    if self.default_saver_behavior is not None:
        default_saver_behavior = self.default_saver_behavior

    behaviors = {}
    for node in graph.get_nodes():
        behavior = HamiltonCacheAdapter._resolve_node_behavior(
@@ -906,6 +938,9 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
            disable=_disable,
            recompute=_recompute,
            ignore=_ignore,
            default_behavior=default_behavior,
            default_loader_behavior=default_loader_behavior,
            default_saver_behavior=default_saver_behavior,
        )
        behaviors[node.name] = behavior
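Note how the loader/saver defaults fall back to `default_behavior` when unset. Reusing the adapter import assumed earlier, the configuration resolves roughly like this:

```python
# Only the general default is set: loader and saver nodes inherit it.
adapter = HamiltonCacheAdapter(default_behavior="recompute")
# -> regular nodes: RECOMPUTE, loader nodes: RECOMPUTE, saver nodes: RECOMPUTE

# A saver-specific default overrides the inherited value for savers only.
adapter = HamiltonCacheAdapter(
    default_behavior="recompute",
    default_saver_behavior="disable",
)
# -> regular nodes: RECOMPUTE, loader nodes: RECOMPUTE, saver nodes: DISABLE
```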
@@ -917,6 +952,41 @@ def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]:
            event_type=CachingEventType.RESOLVE_BEHAVIOR,
            value=behavior,
        )

    # need to handle materializers via a second pass to copy the behavior
    # of their "main node"
    for node in graph.get_nodes():
        if node.tags.get("hamilton.data_loader") is True:
            main_node = node.tags["hamilton.data_loader.node"]
            if main_node == node.name:
                continue

            # solution for `@dataloader` and `from_`
            if behaviors.get(main_node, None) is not None:
                behaviors[node.name] = behaviors[main_node]
            # this hacky section is required to support @load_from and provide
            # a unified pattern to specify behavior from the module or the driver
            else:
                behaviors[node.name] = HamiltonCacheAdapter._resolve_node_behavior(
                    # we create a fake node, only its name matters
                    node=hamilton.node.Node(
                        name=main_node,
                        typ=str,
                        callabl=lambda: None,
                        tags=node.tags.copy(),
                    ),
                    default=_default,
                    disable=_disable,
                    recompute=_recompute,
                    ignore=_ignore,
                    default_behavior=default_loader_behavior,
                )

            self._data_loaders[run_id].append(main_node)

        if node.tags.get("hamilton.data_saver", None) is not None:
            self._data_savers[run_id].append(node.name)

    return behaviors

Review comments attached to the `@load_from` handling above: "Hmm, not sure I follow what's happening here. Let's add some more docs." / "Specifically, document desired behavior here/how the code maps to it."
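As a reviewer aid for the second pass above, here is a hedged sketch of the situation it handles. The decorator usage and loader arguments are illustrative assumptions; the tag names come from this diff.

```python
from hamilton.function_modifiers import load_from, source

# Illustrative @load_from usage; the exact loader and its arguments are assumptions.
@load_from.json(path=source("data_path"))
def raw_data(data: dict) -> dict:
    return data

# Per the tags used above, the generated loader nodes carry
#   {"hamilton.data_loader": True, "hamilton.data_loader.node": "raw_data"}.
# When the "main node" name already has a resolved behavior, the second pass copies it
# onto the generated nodes; otherwise (the @load_from case) it resolves the behavior for
# that name via a throwaway Node, so a driver-level setting such as
#   HamiltonCacheAdapter(recompute=["raw_data"])
# reaches the generated loader nodes under a single, user-facing name.
```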
def resolve_code_versions(
@@ -992,6 +1062,24 @@ def _process_override(self, run_id: str, node_name: str, value: Any) -> None:
        value=data_version,
    )

@staticmethod
def _resolve_default_parameter_values(
    node_: hamilton.node.Node, node_kwargs: dict[str, Any]
) -> dict[str, Any]:
    """
    If a node uses the function's default parameter values, they won't be part of the
    node_kwargs. To ensure a consistent `cache_key`, we want to retrieve default parameter
    values if they're used.
    """
    resolved_kwargs = node_kwargs.copy()
    for param_name, param_value in node_.default_parameter_values.items():
        # if `param_name` is not in `node_kwargs`, it means the node uses the default
        # parameter value
        if param_name not in node_kwargs.keys():
            resolved_kwargs.update(**{param_name: param_value})

    return resolved_kwargs
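A quick illustration of why folding in defaults stabilizes the cache key; the node function and parameter names are hypothetical.

```python
# Hypothetical node with a default parameter value.
def clean_df(raw_df: dict, dropna: bool = True) -> dict:
    return raw_df

# When the DAG runs clean_df without an explicit `dropna`, the framework passes only
# {"raw_df": ...} as node_kwargs. Folding in the default yields
# {"raw_df": ..., "dropna": True}, so the cache key is identical whether the caller
# relies on the default or passes dropna=True explicitly.
node_kwargs = {"raw_df": {"a": [1, None]}}
resolved_kwargs = {"dropna": True, **node_kwargs}  # mirrors _resolve_default_parameter_values
```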
def pre_graph_execute(
    self,
    *,
@@ -1016,6 +1104,10 @@ def pre_graph_execute(
    self.code_versions[run_id] = self.resolve_code_versions(
        run_id=run_id, final_vars=final_vars, inputs=inputs, overrides=overrides
    )
    # the empty `._data_loaders` and `._data_savers` need to be instantiated before calling
    # `self.resolve_behaviors` because it appends to them
    self._data_loaders[run_id] = []
    self._data_savers[run_id] = []
    self.behaviors[run_id] = self.resolve_behaviors(run_id=run_id)

    # final vars are logged to be retrieved by the ``.view_run()`` method
@@ -1060,7 +1152,7 @@ def pre_node_execute(

    """
    node_name = node_.name
    node_kwargs = kwargs
    node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs)

    if self.behaviors[run_id][node_name] == CachingBehavior.IGNORE:
        return
@@ -1154,7 +1246,7 @@ def do_node_execute(
    """
    node_name = node_.name
    node_callable = node_.callable
    node_kwargs = kwargs
    node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs)

    if self.behaviors[run_id][node_name] in (
        CachingBehavior.DISABLE,
@@ -1173,7 +1265,23 @@ def do_node_execute(
        CachingBehavior.IGNORE,
    ):
        cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id)

        # nodes collected in `._data_loaders` return tuples of (result, metadata)
        # where metadata often includes a timestamp. To ensure we provide a consistent
        # `data_version` / hash, we only hash the result part of the materializer return
        # value and discard the metadata.
        if node_name in self._data_loaders[run_id] and isinstance(result, tuple):
            result = result[0]

        data_version = self._version_data(node_name=node_name, run_id=run_id, result=result)

        # nodes collected in `._data_savers` return a dictionary of metadata
        # this metadata often includes a timestamp, leading to an unstable hash.
        # we do not version nor store the metadata. This node is executed for its
        # external effect of saving a file
        if node_name in self._data_savers[run_id]:
            data_version = f"{node_name}__metadata"

        self._set_memory_metadata(
            run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version
        )

Review comment on the data saver special-casing above: "I'm not convinced we need to special case these... But for now, let's call this a bit experimental (data savers/loaders) -- it's a very odd case that we don't want to dwell on more."
@@ -1244,7 +1352,23 @@ def do_node_execute(
            node_kwargs=node_kwargs,
            task_id=task_id,
        )

        # nodes collected in `._data_loaders` return tuples of (result, metadata)
        # where metadata often includes a timestamp. To ensure we provide a consistent
        # `data_version` / hash, we only hash the result part of the materializer return
        # value and discard the metadata.
        if node_name in self._data_loaders[run_id] and isinstance(result, tuple):
            result = result[0]

        data_version = self._version_data(node_name=node_name, run_id=run_id, result=result)

        # nodes collected in `._data_savers` return a dictionary of metadata
        # this metadata often includes a timestamp, leading to an unstable hash.
        # we do not version nor store the metadata. This node is executed for its
        # external effect of saving a file
        if node_name in self._data_savers[run_id]:
            data_version = f"{node_name}__metadata"

        self._set_memory_metadata(
            run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version
        )
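To make the tuple unpacking concrete, here is a hypothetical loader node return value; the shape is inferred from the comments above, and the metadata keys are assumptions.

```python
import time

# Hypothetical return value of a generated loader node: (result, metadata).
loader_output = (
    {"rows": [1, 2, 3]},                              # the loaded data
    {"path": "data.json", "timestamp": time.time()},  # metadata with a timestamp
)

# Hashing the whole tuple would change on every run because of the timestamp,
# so only the result part is versioned, as in the hunks above.
result = loader_output[0] if isinstance(loader_output, tuple) else loader_output
```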
@@ -1298,7 +1422,8 @@ def post_node_execute(
        assert data_version is not SENTINEL

        # TODO clean up this logic
        # check for materialized file when using `@cache(format="json")`
        # check if a materialized file exists before writing results
        # when using `@cache(format="json")`
        cache_format = (
            self._fn_graphs[run_id]
            .nodes[node_name]
Review thread on the `S` type variable and the `-> Union[str, S]` return annotation:
What's the value of this generic type variable? We just have it in outputs, so I'm not sure it'll help with code completion.
Having `-> Union[str, S]` improves readability by specifying that the function can return the `SENTINEL` value. It's much less ambiguous than not using a `TypeVar`, such as `-> Union[str, object]`. It was your recommendation to use a `SENTINEL` value and it did help clarify the code and catch errors. This doesn't have anything to do with code completion and I don't think it needs to?
IMO a typevar really shouldn't be used if it's not linked across two places. E.G. if this class were implemented using a typevar as a generic, then we have value from it. But I guess this is fine -- pylance isn't complaining (although it does if it's only used in one place).

Specifically, type-vars are best (IMO) when cascaded through different parts of the code. E.G. if you pass in an str, you get out an str. Hence what I mean by code completion. In this case it's really just a union type, and not a variable.

Better (IMO) would be `S = Union[str, object]`. But I think it's clearer to have `Union[str, object]` everywhere. Pylance doesn't like it because it can't actually parse that. So this is fine for now.

Funnily enough, `Union[str, object]` is kind of silly, as a `str` is an object. Anyway, not a huge deal, but these are my thoughts.
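For readers following the thread, a small sketch contrasting the two annotations being discussed; the function names and the sentinel definition are illustrative, not taken from the PR.

```python
from typing import TypeVar, Union

SENTINEL = object()  # illustrative; the PR's actual sentinel definition isn't shown here
S = TypeVar("S", object, object)

# Variant discussed above: a TypeVar that only appears in the return annotation.
def lookup_version_typevar(key: str) -> Union[str, S]:
    return SENTINEL  # or a str

# Reviewer's alternative: spell out the union, since the TypeVar isn't linked
# to any parameter and therefore can't carry type information between call sites.
def lookup_version_union(key: str) -> Union[str, object]:
    return SENTINEL  # or a str
```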