From a6b389808095f7601c9614b7da7a6347eef243ec Mon Sep 17 00:00:00 2001 From: Thierry Jean <68975210+zilto@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:30:04 -0400 Subject: [PATCH] feat: First-class caching (#1104) Caching is now a core Hamilton feature. You can add caching by using `Builder().with_cache()` and the cache is accessible via the `Driver.cache` attribute. The feature is implemented using lifecycle hooks, but this could change in Hamilton 2.0. This also introduces the concept of `MetadataStore` and `ResultStore` which could be further reused with the Hamilton UI or in specific deployment scenarios. This also includes a bug fix for the `@dataloader` decorator which had accidentally flipped two tags compared to `@load_from` and `from_.`. The fix includes changes to `@dataloader`, a related test, and to the visualization function in `graph.create_graphviz_graph`. The reference documentation page `caching-logic.rst` includes many TODOs, quirks, limitations, and roadmap items. Eventually, we will need to deprecate the other caching mechanisms. --------- Co-authored-by: zilto --- .gitignore | 3 + docs/concepts/_caching/view_run_example.svg | 167 + docs/concepts/builder.rst | 18 + docs/concepts/caching.rst | 575 +++ docs/concepts/index.rst | 1 + docs/conf.py | 4 +- docs/how-tos/cache-nodes.rst | 9 - docs/how-tos/caching-tutorial.ipynb | 4018 ++++++++++++++++ docs/how-tos/index.rst | 2 +- docs/index.md | 1 + docs/reference/caching/caching-logic.rst | 66 + docs/reference/caching/data-versioning.rst | 6 + docs/reference/caching/index.rst | 13 + docs/reference/caching/stores.rst | 21 + examples/caching/README.md | 6 + examples/caching/materializer_tutorial.ipynb | 2842 ++++++++++++ examples/caching/raw_data.parquet | Bin 0 -> 4294 bytes examples/caching/requirements.txt | 3 + examples/caching/tutorial.ipynb | 4054 +++++++++++++++++ examples/caching_nodes/caching.ipynb | 375 ++ hamilton/caching/__init__.py | 0 hamilton/caching/adapter.py | 1457 ++++++ hamilton/caching/cache_key.py | 53 + hamilton/caching/fingerprinting.py | 267 ++ hamilton/caching/stores/__init__.py | 0 hamilton/caching/stores/base.py | 220 + hamilton/caching/stores/file.py | 89 + hamilton/caching/stores/sqlite.py | 204 + hamilton/caching/stores/utils.py | 21 + hamilton/driver.py | 118 +- hamilton/experimental/h_cache.py | 10 + hamilton/experimental/h_databackends.py | 5 + hamilton/function_modifiers/__init__.py | 1 + hamilton/function_modifiers/adapters.py | 8 +- hamilton/function_modifiers/metadata.py | 101 +- hamilton/graph.py | 2 +- hamilton/lifecycle/default.py | 7 + hamilton/plugins/h_diskcache.py | 7 + pyproject.toml | 2 +- tests/caching/__init__.py | 0 tests/caching/test_adapter.py | 150 + tests/caching/test_fingerprinting.py | 186 + tests/caching/test_integration.py | 619 +++ tests/caching/test_metadata_store.py | 100 + tests/caching/test_result_store.py | 116 + tests/function_modifiers/test_adapters.py | 4 +- .../parallelism_with_caching.py | 16 + tests/test_hamilton_driver.py | 18 + 48 files changed, 15943 insertions(+), 22 deletions(-) create mode 100644 docs/concepts/_caching/view_run_example.svg create mode 100644 docs/concepts/caching.rst delete mode 100644 docs/how-tos/cache-nodes.rst create mode 100644 docs/how-tos/caching-tutorial.ipynb create mode 100644 docs/reference/caching/caching-logic.rst create mode 100644 docs/reference/caching/data-versioning.rst create mode 100644 docs/reference/caching/index.rst create mode 100644 docs/reference/caching/stores.rst create mode 100644 
examples/caching/README.md create mode 100644 examples/caching/materializer_tutorial.ipynb create mode 100644 examples/caching/raw_data.parquet create mode 100644 examples/caching/requirements.txt create mode 100644 examples/caching/tutorial.ipynb create mode 100644 examples/caching_nodes/caching.ipynb create mode 100644 hamilton/caching/__init__.py create mode 100644 hamilton/caching/adapter.py create mode 100644 hamilton/caching/cache_key.py create mode 100644 hamilton/caching/fingerprinting.py create mode 100644 hamilton/caching/stores/__init__.py create mode 100644 hamilton/caching/stores/base.py create mode 100644 hamilton/caching/stores/file.py create mode 100644 hamilton/caching/stores/sqlite.py create mode 100644 hamilton/caching/stores/utils.py create mode 100644 tests/caching/__init__.py create mode 100644 tests/caching/test_adapter.py create mode 100644 tests/caching/test_fingerprinting.py create mode 100644 tests/caching/test_integration.py create mode 100644 tests/caching/test_metadata_store.py create mode 100644 tests/caching/test_result_store.py create mode 100644 tests/resources/dynamic_parallelism/parallelism_with_caching.py diff --git a/.gitignore b/.gitignore index 5b48a18c5..5ed52fc7b 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,6 @@ examples/**/hamilton-env **.pkl **.lance **.txn + +# hamilton default caching directory +**/.hamilton_cache/ diff --git a/docs/concepts/_caching/view_run_example.svg b/docs/concepts/_caching/view_run_example.svg new file mode 100644 index 000000000..8f7b0e127 --- /dev/null +++ b/docs/concepts/_caching/view_run_example.svg @@ -0,0 +1,167 @@ + + + + + + +%3 + + +cluster__legend + +Legend + + + +full_cdf + +full_cdf +Series + + + +probability_before_date + +probability_before_date +Series + + + +full_cdf->probability_before_date + + + + + +due_date + +due_date +datetime + + + +full_pdf + +full_pdf +Series + + + +due_date->full_pdf + + + + + +possible_dates + +possible_dates +Series + + + +due_date->possible_dates + + + + + +probability_distribution + +probability_distribution +rv_continuous + + + +probability_distribution->full_pdf + + + + + +full_pdf->full_cdf + + + + + +probability_on_date + +probability_on_date +Series + + + +full_pdf->probability_on_date + + + + + +possible_dates->probability_before_date + + + + + +possible_dates->probability_on_date + + + + + +_due_date_inputs + +start_date +datetime + + + +_due_date_inputs->due_date + + + + + +_full_pdf_inputs + +start_date +datetime +current_date +Optional + + + +_full_pdf_inputs->full_pdf + + + + + +input + +input + + + +function + +function + + + +output + +output + + + +from cache + +from cache + + + diff --git a/docs/concepts/builder.rst b/docs/concepts/builder.rst index ab453d5d6..1ce0afd17 100644 --- a/docs/concepts/builder.rst +++ b/docs/concepts/builder.rst @@ -179,6 +179,24 @@ Adds `DataSaver` and `DataLoader` nodes to your dataflow. This allows to visuali :align: center +with_cache() +-------------- + +This enables Hamilton's caching feature, which allows to automatically store intermediary results and reuse them in subsequent executions to skip computations. Learn more in the :doc:`/concepts/caching` section. + +.. 
code-block:: python
+
+    from hamilton import driver
+    import my_dataflow
+
+    dr = (
+        driver.Builder()
+        .with_modules(my_dataflow)
+        .with_cache()
+        .build()
+    )
+

with_adapters()
---------------
diff --git a/docs/concepts/caching.rst b/docs/concepts/caching.rst new file mode 100644 index 000000000..d9da62144 --- /dev/null +++ b/docs/concepts/caching.rst @@ -0,0 +1,575 @@
+========
+Caching
+========
+
+Caching enables storing execution results to be reused in later executions, effectively skipping redundant computations. This speeds up execution and saves resources (computation, API credits, GPU time, etc.), and has applications in both development and production.
+
+To enable caching, add ``.with_cache()`` to your ``Builder()``.
+
+.. code-block:: python
+
+    from hamilton import driver
+    import my_dataflow
+
+    dr = (
+        driver.Builder()
+        .with_modules(my_dataflow)
+        .with_cache()
+        .build()
+    )
+
+    dr.execute([...])
+    dr.execute([...])
+
+
+The first execution will store **metadata** and **results** next to the current directory under ``./.hamilton_cache``. The next execution will retrieve results from the cache when possible to skip execution.
+
+.. note::
+
+    We highly suggest viewing the :doc:`../how-tos/caching-tutorial` tutorial for a practical introduction to caching.
+
+
+How does it work?
+-----------------
+
+Caching relies on multiple components:
+
+- **Cache adapter**: decides whether to retrieve a result or execute the node
+- **Metadata store**: stores information about past node executions
+- **Result store**: stores results on disk; it is unaware of the other cache components.
+
+At a high level, the cache adapter does the following for each node:
+
+1. Before execution: determine the ``cache_key``
+
+2. At execution:
+    a. If the ``cache_key`` finds a match in the metadata store (cache **hit**), retrieve the ``data_version`` of the ``result``.
+    b. If there's no match (cache **miss**), execute the node and store the ``data_version`` of the ``result`` in the metadata store.
+
+3. After execution: if we had to execute the node, store the ``result`` in the result store.
+
+The caching mechanism is highly performant because it can pass ``data_version`` (small strings) through the dataflow instead of the actual data until a node needs to be executed.
+
+The result store is a mapping of ``{data_version: result}``. While a ``cache_key`` is unique to determine retrieval or execution, multiple cache keys can point to the same ``data_version``, which avoids storing duplicate results.
+
+Cache key
+~~~~~~~~~
+
+Understanding the ``cache_key`` helps explain why a node is recomputed or not. It is composed of:
+
+- ``node_name``: name of the node
+- ``code_version``: version of the node's code
+- ``dependencies_data_versions``: ``data_version`` of each dependency of the node
+
+.. code-block:: json
+
+    {
+        "node_name": "processed_data",
+        "code_version": "c2ccafa54280fbc969870b6baa445211277d7e8cfa98a0821836c175603ffda2",
+        "dependencies_data_versions": {
+            "raw_data": "WgV5-4SfdKTfUY66x-msj_xXsKNPNTP2guRhfw==",
+            "date": "ZWNhd-XNlIF0YV9-2ZXJzaW9u_YGAgKA=="
+        }
+    }
+
+By traversing the cache keys' ``dependencies_data_versions``, we can actually reconstruct the dataflow structure!
+
+.. warning::
+
+    Cache keys could be unstable across Python and Hamilton versions (because of new features, bug fixes, etc.). Upgrading Python or Hamilton could require starting with a new empty cache for reliable behavior.
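To make this composition concrete, here is an illustrative sketch of how such a key could be derived by hashing the three components together. This is for intuition only; Hamilton's actual key construction lives in ``hamilton/caching/cache_key.py`` (added by this PR) and may differ.

.. code-block:: python

    import hashlib
    import json

    def illustrative_cache_key(
        node_name: str, code_version: str, dependencies_data_versions: dict
    ) -> str:
        """Illustration only: combine the three components into a single stable key."""
        payload = json.dumps(
            {
                "node_name": node_name,
                "code_version": code_version,
                "dependencies_data_versions": dependencies_data_versions,
            },
            sort_keys=True,  # make the key independent of dependency ordering
        )
        return hashlib.sha256(payload.encode()).hexdigest()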
+

Observing the cache
-------------------

Caching is best understood through interacting with it. Hamilton offers many utilities to observe and introspect the cache manually.

Logging
~~~~~~~

To see how the cache works step-by-step, start your code (script, notebook, etc.) by getting the logger and setting the level to ``DEBUG``. Using ``INFO`` will be less noisy and only log ``GET_RESULT`` and ``EXECUTE_NODE`` events.

.. code-block:: python

    import logging

    logger = logging.getLogger("hamilton.caching")
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())  # this handler will print to the console

The logs follow the structure ``{node_name}::{task_id}::{actor}::{event_type}::{message}``, omitting empty sections.


.. code-block:: console

    # example INFO logs for nodes foo, bar, and baz
    foo::result_store::get_result::hit
    bar::adapter::execute_node
    baz::adapter::execute_node


Visualization
~~~~~~~~~~~~~~

After ``Driver`` execution, calling ``dr.cache.view_run()`` will create a visualization of the dataflow with results retrieved from the cache highlighted.

By default, it shows the latest run, but it's possible to view previous runs by passing a ``run_id``. Specify an ``output_file_path`` to save the visualization.

.. code-block:: python

    # ... define and execute a `Driver`

    # select the 3rd unique run_id
    run_id_3 = dr.cache.run_ids[2]
    dr.cache.view_run(run_id=run_id_3, output_file_path="cached_run_3.png")


.. figure:: _caching/view_run_example.svg

    Visualization produced by ``dr.cache.view_run()``. Retrieved results are outlined.


.. note::

    The method ``.view_run()`` doesn't currently support task-based execution or ``Parallelizable/Collect``.


.. _caching-structured-logs:

Structured logs
~~~~~~~~~~~~~~~

Structured logs are stored on the ``Driver.cache`` and can be inspected programmatically. By setting ``.with_cache(log_to_file=True)``, structured logs will also be appended to a ``.jsonl`` file as they happen; this is ideal for production usage.

To access logs, use ``Driver.cache.logs()``. You can set ``.logs(level=...)`` to ``"info"`` or ``"debug"`` to view only ``GET_RESULT`` and ``EXECUTE_NODE`` events or all events, respectively. Specifying ``.logs(run_id=...)`` will return logs from a given run, and leaving it empty will return logs for all executions of this ``Driver``.

.. code-block:: python

    dr.execute(...)
    dr.cache.logs(level="info")

The shape of the returned object is slightly different depending on whether you specify a ``run_id``. Specifying a ``run_id`` will give ``{node_name: List[CachingEvent]}``.

Requesting ``Driver.cache.logs()`` will return a dictionary with ``run_id`` as keys and lists of ``CachingEvent`` as values: ``{run_id: List[CachingEvent]}``. This is useful for comparing runs and verifying that nodes were properly executed or retrieved.


.. code-block:: python

    dr.cache.logs(level="debug", run_id=dr.cache.last_run_id)
    # {
    #    'raw_data': [CachingEvent(...), ...],
    #    'processed_data': [CachingEvent(...), ...],
    #    'amount_per_country': [CachingEvent(...), ...]
    # }

    dr.cache.logs(level="debug")
    # {
    #    'run_id_1': [CachingEvent(...), ...],
    #    'run_id_2': [CachingEvent(...), ...]
    # }

.. note::

    When using ``Parallelizable/Collect``, nodes that are part of the "parallel branches" will also have a ``task_id`` key (``{node_name: {task_id: List[CachingEvent]}}``) while nodes outside branches will remain ``{node_name: List[CachingEvent]}``.
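For example, structured logs can be tallied to see how much of a run was served from the cache. The sketch below is an assumption-laden illustration: it assumes ``dr`` is a ``Driver`` built with ``.with_cache()`` that has already executed, reuses the ``CachingEvent`` attributes shown later on this page, and assumes the ``GET_RESULT`` and ``EXECUTE_NODE`` events correspond to members of ``CachingEventType``.

.. code-block:: python

    from collections import Counter

    from hamilton.lifecycle.caching import CachingEventType

    # tally retrievals vs. executions for the latest run (illustrative sketch)
    counts = Counter()
    for node_name, events in dr.cache.logs(level="info", run_id=dr.cache.last_run_id).items():
        for event in events:
            if event.event_type == CachingEventType.GET_RESULT:
                counts["retrieved"] += 1
            elif event.event_type == CachingEventType.EXECUTE_NODE:
                counts["executed"] += 1

    print(counts)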
.. _cache-result-format:

Cached result format
---------------------

By default, caching uses the ``pickle`` format because it can accommodate almost all Python objects. However, it has `caveats `_. The ``cache`` decorator allows you to use a different format for a given node (``JSON``, ``CSV``, ``Parquet``, etc.).

The next snippet caches ``clean_dataset`` as ``parquet`` and ``statistics`` as ``json``. These formats may be more reliable, efficient, and easier to work with.

.. code-block:: python

    # my_dataflow.py
    import pandas as pd
    from hamilton.function_modifiers import cache

    def raw_data(path: str) -> pd.DataFrame:
        return pd.read_csv(path)

    @cache(format="parquet")
    def clean_dataset(raw_data: pd.DataFrame) -> pd.DataFrame:
        raw_data = raw_data.fillna(0)
        return raw_data

    @cache(format="json")
    def statistics(clean_dataset: pd.DataFrame) -> dict:
        return ...


.. code-block:: python

    from hamilton import driver
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache()
        .build()
    )

    # the first execution will produce a ``parquet`` file for ``clean_dataset``
    # and a ``json`` file for ``statistics``
    dr.execute(["statistics"])
    # the second execution will use these parquet and json files when loading results
    dr.execute(["statistics"])

.. note::

    Internally, this uses :doc:`Materializers `.

Caching behavior
-----------------

The **caching behavior** refers to the caching logic used to:
- version data
- load and store metadata
- load and store results
- execute a node or not

The ``DEFAULT`` behavior aims to be easy to use and facilitate iterative development. However, other behaviors may be desirable in particular scenarios or when going to production. The behavior can be set node-by-node.

1. ``DEFAULT``: Try to retrieve results from cache instead of executing the node. Node result and metadata are stored.

2. ``RECOMPUTE``: Always execute the node / never retrieve from cache. Result and metadata are stored. This can be useful to ensure external data is always reloaded.

3. ``DISABLE``: Act as if caching isn't enabled for this node. Nodes depending on a disabled node will miss metadata for cache retrieval, forcing their re-execution. Useful for disabling caching in parts of the dataflow.

4. ``IGNORE``: Similar to ``DISABLE``, but downstream nodes will ignore the missing metadata and can successfully retrieve results. Useful to ignore "irrelevant" nodes that shouldn't impact the results (e.g., credentials, API clients, database connections).

.. seealso::

    Learn more in the :doc:`/reference/caching/caching-logic` reference section.

.. note::

    There are other caching behaviors theoretically possible, but these four should cover most cases. Let us know if you have a use case that is not covered.
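To illustrate the ``IGNORE`` behavior (item 4), here is a minimal sketch; the ``client`` node and the ``my_api_library`` package are hypothetical. Ignoring the client keeps it from impacting downstream cache keys, so ``enriched_data`` can still be retrieved from the cache. How to pass the ``ignore`` argument is detailed in the next sections.

.. code-block:: python

    # my_dataflow.py (hypothetical nodes for illustration)
    import my_api_library

    def client(api_key: str) -> my_api_library.Client:
        # credentials and clients shouldn't impact downstream cache retrieval
        return my_api_library.Client(api_key)

    def enriched_data(client: my_api_library.Client, raw_data: dict) -> dict:
        return {"enriched": client.annotate(raw_data)}

.. code-block:: python

    from hamilton import driver
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache(ignore=["client"])  # downstream nodes ignore the client's missing metadata
        .build()
    )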
Setting caching behavior
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The caching behavior can be specified at the node level via the ``@cache`` function modifier or at the builder level via ``.with_cache(...)`` arguments. Note that the behavior specified by the ``Builder`` will override the behavior from ``@cache`` since it's closer to execution.

via ``@cache``
~~~~~~~~~~~~~~~

Below, we set ``raw_data`` to ``RECOMPUTE`` because the file it loads data from may change between executions. After executing and versioning the result of ``raw_data``, if the data didn't change from the previous execution, we'll be able to retrieve ``clean_dataset`` and ``statistics`` from cache.

.. code-block:: python

    # my_dataflow.py
    import pandas as pd
    from hamilton.function_modifiers import cache

    @cache(behavior="recompute")
    def raw_data(path: str) -> pd.DataFrame:
        return pd.read_csv(path)

    def clean_dataset(raw_data: pd.DataFrame) -> pd.DataFrame:
        raw_data = raw_data.fillna(0)
        return raw_data

    def statistics(clean_dataset: pd.DataFrame) -> dict:
        return ...


via ``Builder().with_cache()``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Equivalently, we could set this behavior via the ``Builder``. You can pass a list of node names to the keyword arguments ``recompute``, ``ignore``, and ``disable``. Pass ``True`` to enable that behavior for all nodes. For example, using ``recompute=True`` will force execution of all nodes and store their results in cache. Having ``disable=True`` is equivalent to not having the ``.with_cache()`` clause.

.. code-block:: python

    from hamilton import driver
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache(recompute=["raw_data"])
        .build()
    )

Set a default behavior
~~~~~~~~~~~~~~~~~~~~~~

By default, caching is "opt-out", meaning all nodes are cached unless specified otherwise. To make it "opt-in", where only the specified nodes are cached, set ``default_behavior="disable"``. You can also try different default behaviors.

.. code-block:: python

    from hamilton import driver
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache(
            default=["raw_data", "statistics"],  # set behavior DEFAULT
            default_behavior="disable"  # all other nodes are DISABLE
        )
        .build()
    )


Code version
------------

The ``code_version`` of a node is determined by hashing its source code, ignoring docstrings and comments.

Importantly, Hamilton will not version nested function calls. If you edit utility functions or upgrade Python libraries, the cache might incorrectly assume the code to be the same.

For example, take the following function ``foo``:

.. code-block:: python

    def _increment(x):
        return x + 1

    def foo():
        return _increment(13)

    # foo's code version: 129064d4496facc003686e0070967051ceb82c354508a58440910eb82af300db


Despite editing the nested ``_increment()``, we get the same ``code_version`` because the content of ``foo()`` hasn't changed.

.. code-block:: python

    def _increment(x):
        return x + 2

    def foo():
        return _increment(13)

    # foo's code version: 129064d4496facc003686e0070967051ceb82c354508a58440910eb82af300db

After this edit, ``foo()`` should return ``13 + 2`` instead of ``13 + 1``. Unaware of the change in ``_increment()``, the cache will find a ``cache_key`` match and return ``13 + 1``.

One solution is to set the caching behavior to ``RECOMPUTE`` to force the execution of ``foo()``. Another is to delete stored metadata or results to force re-execution.
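As a sketch of the first workaround, reusing the ``@cache`` modifier introduced above, ``foo`` can be marked as ``RECOMPUTE`` so edits to the nested ``_increment()`` are always picked up:

.. code-block:: python

    from hamilton.function_modifiers import cache

    def _increment(x):
        return x + 2

    # nested calls such as ``_increment()`` are not versioned,
    # so always re-execute ``foo`` instead of trusting the cache
    @cache(behavior="recompute")
    def foo() -> int:
        return _increment(13)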
Data version
------------

Caching requires the ability to uniquely identify data (e.g., create a hash). By default, all Python primitive types (``int``, ``str``, ``dict``, etc.) are supported and more types can be added via extensions (e.g., ``pandas``). For types not explicitly supported, caching can still function by versioning the object's internal ``__dict__`` instead. However, this could be expensive to compute or less reliable than alternatives.

Recursion depth
~~~~~~~~~~~~~~~

To version complex objects, we recursively hash their values. For example, versioning an object ``List[Dict[str, float]]`` involves hashing all keys and values of all dictionaries. Versioning complex objects with large ``__dict__`` state can become expensive.

In practice, we need to set a maximum recursion depth because there's a trade-off between the computational cost of hashing data and how accurately it uniquely identifies data (reducing hash collisions).

Here's how to set the max depth:

.. code-block:: python

    from hamilton.io import fingerprinting
    fingerprinting.set_max_depth(depth=3)


Support additional types
~~~~~~~~~~~~~~~~~~~~~~~~~

Additional types can be supported by registering a hashing function via the module ``hamilton.io.fingerprinting``. It uses `@functools.singledispatch `_ to register the hashing function per Python type. The function must return a ``str``. The code snippet below shows how to support polars ``DataFrame``:

.. code-block:: python

    import polars as pl
    from hamilton.io import fingerprinting

    # specify the type via the decorator
    @fingerprinting.hash_value.register(pl.DataFrame)
    def hash_polars_dataframe(obj, *args, **kwargs) -> str:
        """Convert a polars dataframe to a list of row hashes, then hash the list.
        We consider that row order matters.
        """
        # obj is of type `pl.DataFrame`
        hash_per_row = obj.hash_rows(seed=0)
        # fingerprinting.hash_value(...) will automatically hash primitive Python types
        return fingerprinting.hash_value(hash_per_row)

Alternatively, you can register functions without using decorators.

.. code-block:: python

    from hamilton.io import fingerprinting

    def hash_polars_dataframe(obj, *args, **kwargs) -> str: ...

    fingerprinting.hash_value.register(pl.DataFrame, hash_polars_dataframe)


If you want to override the base case, the one defined by the function ``hash_value()``, you can do so by registering a function for the type ``object``.

.. code-block:: python

    @fingerprinting.hash_value.register(object)
    def hash_object(obj, *args, **kwargs) -> str: ...


Storage
-------

The caching feature is powered by two data stores:

- **Metadata store**: It contains information about past ``Driver`` executions (**code version**, **data version**, run id, etc.). From this metadata, Hamilton determines if a node needs to be executed or not. This metadata is generally lightweight.

- **Result store**: It's a key-value store that maps a **data version** to a **result**. It's completely unaware of nodes, executions, etc. and simply holds the **results**. The result store can significantly grow in size depending on your usage. By default, all results are pickled, but :ref:`other formats are possible <cache-result-format>`.


Setting the cache path
~~~~~~~~~~~~~~~~~~~~~~

By default, the **metadata** and **results** are stored under a new subdirectory ``./.hamilton_cache/``, next to the current directory. Alternatively, you can set a path via ``.with_cache(path=...)`` that will be applied to both stores.
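For example, a minimal sketch that points both stores at a custom directory (the path here is only illustrative):

.. code-block:: python

    from hamilton import driver
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache(path="./my_cache")  # metadata and results are both stored here
        .build()
    )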
By project
^^^^^^^^^^
Centralizing your cache by project is useful when you have nodes that are reused across multiple dataflows (e.g., training and inference ML pipelines, feature engineering).


.. code-block:: python

    # training_script.py
    from hamilton import driver
    import training

    cache_path = "/path/to/project/hamilton_cache"
    train_dr = driver.Builder().with_modules(training).with_cache(path=cache_path).build()

    # inference_script.py
    from hamilton import driver
    import inference

    cache_path = "/path/to/project/hamilton_cache"
    predict_dr = driver.Builder().with_modules(inference).with_cache(path=cache_path).build()


Globally
^^^^^^^^^^

A global cache makes storage management easier. Since the metadata and the results for *all* your Hamilton dataflows are in one place, it can be easier to clean up disk space.

.. code-block:: python

    import pathlib
    from hamilton import driver
    import my_dataflow

    # set the cache under the user's home directory for any operating system
    # The `Path` is converted to a string.
    cache_path = str(pathlib.Path.home().joinpath(".hamilton_cache"))
    dr = driver.Builder().with_modules(my_dataflow).with_cache(path=cache_path).build()

.. hint::

    It can be a good idea to store the cache path in an environment variable.

Separate locations
^^^^^^^^^^^^^^^^^^

If you want the metadata and result stores to be at different locations, you can instantiate and pass them to ``.with_cache()``. In that case, ``.with_cache()``'s ``path`` parameter will be ignored.

.. code-block:: python

    from hamilton import driver
    from hamilton.io.store import SQLiteMetadataStore, ShelveResultStore
    import my_dataflow

    metadata_store = SQLiteMetadataStore(path="~/.hamilton_cache")
    result_store = ShelveResultStore(path="/path/to/my/project")

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache(
            metadata_store=metadata_store,
            result_store=result_store,
        )
        .build()
    )


Inspect storage
~~~~~~~~~~~~~~~

It is possible to directly interact with the metadata and result stores, either by instantiating them or via ``Driver.cache``.


.. code-block:: python

    from hamilton.io.store import SQLiteMetadataStore, ShelveResultStore

    metadata_store = SQLiteMetadataStore(path="~/.hamilton_cache")
    result_store = ShelveResultStore(path="/path/to/my/project")

    metadata_store.get(context_key=...)
    result_store.get(data_version=...)


.. code-block:: python

    from hamilton import driver
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache()
        .build()
    )

    dr.cache.metadata_store.get(context_key=...)
    dr.cache.result_store.get(data_version=...)


A useful pattern is using the ``Driver.cache`` state or :ref:`structured logs <caching-structured-logs>` to retrieve a **data version** and query the **result store**.

.. code-block:: python

    from hamilton import driver
    from hamilton.lifecycle.caching import CachingEventType
    import my_dataflow

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache()
        .build()
    )

    dr.execute(["amount_per_country"])

    # via `cache.data_versions`; this points to the latest run
    data_version = dr.cache.data_versions["amount_per_country"]
    stored_result = dr.cache.result_store.get(data_version)

    # via structured logs; this allows querying any run
    run_id = ...
    for event in dr.cache.logs(level="debug")[run_id]:
        if (
            event.event_type == CachingEventType.SET_RESULT
            and event.node_name == "amount_per_country"
        ):
            data_version = event.value
            break

    stored_result = dr.cache.result_store.get(data_version)


Roadmap
-------

Caching is a significant Hamilton feature and there are plans to expand it.
Here are some ideas and areas for development. Feel free to comment on them or make other suggestions via `Slack `_ or GitHub!

- **Hamilton UI integration**: caching introduces the concept of ``data_version``. This metadata could be captured by the Hamilton UI to show how different values are used across dataflow executions. This would be particularly useful for experiment tracking and lineage.
- **Distributed caching support**: the initial release supports multithreading and multiprocessing on a single machine. For distributed execution, we will need a ``ResultStore`` and ``MetadataStore`` that can be remote and are safe for concurrent access.
- **Integrate with remote execution** (Ray, Skypilot, Modal, Runhouse): facilitate a pattern where the dataflow is executed locally, but some nodes can selectively be executed remotely and have their results cached locally.
- **async support**: Support caching with ``AsyncDriver``. This requires a significant amount of code, but the core logic shouldn't change much.
- **cache eviction**: Allow setting a maximum storage (in size or number of items) or a time-based policy to delete data from the metadata and result stores. This would help with managing the cache size.
- **more store backends**: The initial release includes backends supported by the Python standard library (SQLite metadata and file-based results). We could support more backends via `fsspec `_ (AWS, Azure, GCP, Databricks, etc.).
- **support more types**: Include specialized hashing functions for complex objects from popular libraries. This can be done through Hamilton extensions.
diff --git a/docs/concepts/index.rst b/docs/concepts/index.rst index 80d75ddea..dceaa0828 100644 --- a/docs/concepts/index.rst +++ b/docs/concepts/index.rst @@ -15,6 +15,7 @@ concepts that makes Hamilton unique and powerful. materialization function-modifiers builder + caching function-modifiers-advanced parallel-task ui diff --git a/docs/conf.py b/docs/conf.py index 2ace01545..f6f32968c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,11 +28,13 @@ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.autosummary", - "myst_parser", + "myst_nb", "sphinx_sitemap", "docs.data_adapters_extension", ] +nb_execution_mode = "off" + # for the sitemap extension --- # check if the current commit is tagged as a release (vX.Y.Z) and set the version GIT_TAG_OUTPUT = subprocess.check_output(["git", "tag", "--points-at", "HEAD"]) diff --git a/docs/how-tos/cache-nodes.rst b/docs/how-tos/cache-nodes.rst deleted file mode 100644 index 043754ddb..000000000 --- a/docs/how-tos/cache-nodes.rst +++ /dev/null @@ -1,9 +0,0 @@ -====================== -Caching results -====================== - -Sometimes it is convenient to cache intermediate nodes. This is especially useful during development. - -For example, if a particular node takes a long time to calculate (perhaps it extracts data from an outside source or performs some heavy computation), you can annotate it with "cache" tag. The first time the DAG is executed, that node will be cached to disk. If then you do some development on any of the downstream nodes, the subsequent executions will load the cached node instead of repeating the computation. - -See the examples here `here `_.
diff --git a/docs/how-tos/caching-tutorial.ipynb b/docs/how-tos/caching-tutorial.ipynb new file mode 100644 index 000000000..b25ff2909 --- /dev/null +++ b/docs/how-tos/caching-tutorial.ipynb @@ -0,0 +1,4018 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Caching\n", + "In Hamilton, **caching** broadly refers to \"reusing results from previous executions to skip redundant computation\". If you change code or pass new data, it will automatically determine which results can be reused and which nodes need to be re-executed. This improves execution speed and reduces resource usage (computation, API credits, etc.).\n", + "\n", + "```{note}\n", + "\n", + "Open the notebook in [Google Colab](https://colab.research.google.com/github/DAGWorks-Inc/hamilton/blob/main/examples/caching/tutorial.ipynb) for an interactive version and better syntax highlighting.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Throughout this tutorial, we'll be using the Hamilton notebook extension to define dataflows directly in the notebook ([see tutorial](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/jupyter_notebook_magic/example.ipynb)).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton import driver\n", + "\n", + "# load the notebook extension\n", + "%reload_ext hamilton.plugins.jupyter_magic" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We import the `logging` module and get the logger from `hamilton.caching`. With the level set to ``INFO``, we'll see ``GET_RESULT`` and ``EXECUTE_NODE`` cache events as they happen." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logger = logging.getLogger(\"hamilton.caching\")\n", + "logger.setLevel(logging.INFO)\n", + "logger.addHandler(logging.StreamHandler())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next cell deletes the cached data to ensure this notebook can be run from top to bottom without any issues." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "shutil.rmtree(\"./.hamilton_cache\", ignore_errors=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basics\n", + "\n", + "Throughout this notebook, we'll use the same simple dataflow that processes transactions in various locations and currencies.\n", + "\n", + "We use the cell magic `%%cell_to_module` from the Hamilton notebook extension. It will convert the content of the cell into a Python module that can be loaded by Hamilton. The `--display` flag allows to visualize the dataflow." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module basics_module --display\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.73\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we build the ``Driver`` with caching enabled and execute the dataflow." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n" + ] + } + ], + "source": [ + "basics_dr = driver.Builder().with_modules(basics_module).with_cache().build()\n", + "\n", + "basics_results_1 = basics_dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(basics_results_1[\"processed_data\"].head())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can view what values were retrieved from the cache using `dr.cache.view_run()`. Since this was the first execution, nothing is retrieved." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "basics_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On the second execution, `processed_data` is retrieved from cache as reported in the logs and highlighted in the visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "basics_results_2 = basics_dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(basics_results_2[\"processed_data\"].head())\n", + "print()\n", + "basics_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Understanding the `cache_key`\n", + "\n", + "The Hamilton cache stores results using a `cache_key`. 
It is composed of the node's name (`node_name`), the code that defines it (`code_version`), and its data inputs (`data_version` of its dependencies).\n", + "\n", + "For example, the cache keys for the previous cells are:\n", + "\n", + "```json\n", + "{\n", + " \"node_name\": \"raw_data\",\n", + " \"code_version\": \"9d727859b9fd883247c3379d4d25a35af4a56df9d9fde20c75c6375dde631c68\",\n", + " \"dependencies_data_versions\": {} // it has no dependencies\n", + "}\n", + "{\n", + " \"node_name\": \"processed_data\",\n", + " \"code_version\": \"c9e3377d6c5044944bd89eeb7073c730ee8707627c39906b4156c6411f056f00\",\n", + " \"dependencies_data_versions\": {\n", + " \"cutoff_date\": \"WkGjJythLWYAIj2Qr8T_ug==\", // input value\n", + " \"raw_data\": \"t-BDcMLikFSNdn4piUKy1mBcKPoEsnsYjUNzWg==\" // raw_data's result\n", + " }\n", + "}\n", + "```\n", + "\n", + "Results could be successfully retrieved because nodes in the first execution and second execution shared the same `cache_key`.\n", + "\n", + "The `cache_key` objects are internal and you won't have to interact with them directly. However, keep that concept in mind throughout this tutorial. Towards the end, we show how to manually handle the `cache_key` for debugging." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding a node\n", + "\n", + "Let's say you're iteratively developing your dataflow and you add a new node. Here, we copy the previous module into a new module named `adding_node_module` and define the node `amount_per_country`.\n", + "\n", + "> In practice, you would edit the cell directly, but this makes the notebook easier to read and maintain" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module adding_node_module --display\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.73\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We build a new `Driver` with `adding_node_module` and execute the dataflow. You'll notice that `raw_data` and `processed_data` are retrieved and only `amount_per_country` is executed." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 968.491\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adding_node_dr = driver.Builder().with_modules(adding_node_module).with_cache().build()\n", + "\n", + "adding_node_results = adding_node_dr.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(adding_node_results[\"amount_per_country\"].head())\n", + "print()\n", + "adding_node_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Even though this is the first execution of `adding_node_dr` and the module `adding_node_module`, the cache contains results for `raw_data` and `processed_data`. We're able to retrieve values because they have the same cache keys (code version and dependencies data versions).\n", + "\n", + "This means you can reuse cached results across dataflows. This is particularly useful with training and inference machine learning pipelines." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing inputs\n", + "\n", + "We reuse the same dataflow `adding_node_module`, but change the input `cutoff_date` from\n", + "`\"2024-09-01\"` to `\"2024-09-11\"`. \n", + "\n", + "\n", + "This new input forces `processed_data` to be re-executed. This produces a new result for `processed_data`, which cascades and also forced `amount_per_country` to be re-executed." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "USA 729.9\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_inputs_dr = driver.Builder().with_modules(adding_node_module).with_cache().build()\n", + "\n", + "changing_inputs_results_1 = changing_inputs_dr.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-11\"}\n", + ")\n", + "print()\n", + "print(changing_inputs_results_1[\"amount_per_country\"].head())\n", + "print()\n", + "changing_inputs_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we execute with the `cutoff_date` value `\"2024-09-05\"`, which forces `processed_data` to be executed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 968.491\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_inputs_results_2 = changing_inputs_dr.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-05\"}\n", + ")\n", + "print()\n", + "print(changing_inputs_results_2[\"amount_per_country\"].head())\n", + "print()\n", + "changing_inputs_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the cache could still retrieve `amount_per_country`. This is because `processed_data` return a value that had been cached previously (in the `Adding a node` section).\n", + "\n", + "In concrete terms, filtering rows by the date `\"2024-09-05\"` or `\"2024-09-01\"` includes the same rows and produces the same dataframe." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n", + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n" + ] + } + ], + "source": [ + "print(adding_node_results[\"processed_data\"])\n", + "print()\n", + "print(changing_inputs_results_2[\"processed_data\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing code\n", + "As you develop your dataflow, you will need to edit upstream nodes. 
Caching will automatically detect code changes and determine which node needs to be re-executed. In `processed_data()`, we'll change the conversion rate from `0.73` to `0.71`.\n",
+    "\n",
+    "> NOTE. changes to docstrings and comments `#` are ignored when versioning a node."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%cell_to_module changing_code_module\n",
+    "import pandas as pd\n",
+    "\n",
+    "DATA = {\n",
+    "    \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n",
+    "    \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n",
+    "    \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n",
+    "    \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n",
+    "    \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n",
+    "}\n",
+    "\n",
+    "def raw_data() -> pd.DataFrame:\n",
+    "    \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n",
+    "    return pd.DataFrame(DATA)\n",
+    "\n",
+    "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n",
+    "    \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n",
+    "    df = raw_data.loc[raw_data.date > cutoff_date].copy()\n",
+    "    df[\"amound_in_usd\"] = df[\"amount\"]\n",
+    "    df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71  # <- VALUE CHANGED FROM module_2\n",
+    "    return df\n",
+    "\n",
+    "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n",
+    "    \"\"\"Sum the amount in USD per country\"\"\"\n",
+    "    return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We need to execute `processed_data` because the code change created a new `cache_key` and led to a cache miss. 
Then, `processed_data` returns a previously unseen value, forcing `amount_per_country` to also be re-executed" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 941.957\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_code_dr_1 = driver.Builder().with_modules(changing_code_module).with_cache().build()\n", + "\n", + "changing_code_results_1 = changing_code_dr_1.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(changing_code_results_1[\"amount_per_country\"].head())\n", + "print()\n", + "changing_code_dr_1.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We make another code change to `processed_data` to accomodate currency conversion for Brazil and Mexico." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_code_module_2\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18 # <- LINE ADDED\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05 # <- LINE ADDED\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, the code change forces `processed_data` to be executed." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 941.957\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_code_dr_2 = driver.Builder().with_modules(changing_code_module_2).with_cache().build()\n", + "\n", + "changing_code_results_2 = changing_code_dr_2.execute([\"processed_data\",\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(changing_code_results_2[\"amount_per_country\"].head())\n", + "print()\n", + "changing_code_dr_2.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, `amount_per_country` can be retrieved because `processed_data` returned a previously seen value.\n", + "\n", + "In concrete terms, adding code to process currency from Brazil and Mexico didn't change the `processed_data` result because it only includes data from the USA and Canada.\n", + "\n", + "> NOTE. This is similar to what happened at the end of the section **Changing inputs**." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n", + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n" + ] + } + ], + "source": [ + "print(changing_code_results_1[\"processed_data\"])\n", + "print()\n", + "print(changing_code_results_2[\"processed_data\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing external data\n", + "\n", + "Hamilton's caching mechanism uses the node's `code_version` and its dependencies' `data_version` to determine whether the node needs to be executed or its result can be retrieved from the cache. By default, it assumes [idempotency](https://www.astronomer.io/docs/learn/dag-best-practices#review-idempotency) of operations.\n", + "\n", + "This section covers how to handle nodes with external effects, such as reading or writing external data.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Idempotency\n", + "\n", + "To illustrate idempotency, let's use this minimal dataflow, which has a single node that returns the current date and time:\n", + "\n", + "```python\n", + "import datetime\n", + "\n", + "def current_datetime() -> datetime.datetime:\n", + " return datetime.datetime.now()\n", + "```\n", + "\n", + "The first execution will execute the node and store the resulting date and time. On the second execution, the cache will read the stored result instead of re-executing. Why? Because the `code_version` is the same and the `data_version` of its dependencies (it has none) hasn't changed.\n", + "\n", + "A similar situation occurs when reading from external data, as shown here:\n", + "\n", + "```python\n", + "import pandas as pd\n", + "\n", + "def dataset(file_path: str) -> pd.DataFrame:\n", + " return pd.read_csv(file_path)\n", + "```\n", + "\n", + "Here, the code of `dataset()` and the value for `file_path` can stay the same, but the file itself could be updated (e.g., new rows added).\n", + "\n", + "The next sections show how to always re-execute a node and ensure the latest data is used. The `DATA` constant is modified with transactions in Brazil and Mexico to simulate `raw_data` loading a new dataset."
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_external_module\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At execution, we see `raw_data` being retrieved along with all downstream nodes. Also, we note that the printed results don't include Brazil or Mexico: since the code didn't change, the cached `raw_data` (the old dataset) was reused."
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 941.957\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_external_dr = driver.Builder().with_modules(changing_external_module).with_cache().build()\n", + "\n", + "changing_external_results = changing_external_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(changing_external_results[\"amount_per_country\"].head())\n", + "print()\n", + "changing_external_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `.with_cache()` to specify caching behavior\n", + "Here, we build a new `Driver` with the same `changing_external_module`, but we specify in `.with_cache()` to always recompute `raw_data`. \n", + "\n", + "The visualization shows that `raw_data` was executed, and because of the new data, all downstream nodes also need to be executed. The results now include Brazil and Mexico." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_external_with_cache_dr = driver.Builder().with_modules(changing_external_module).with_cache(recompute=[\"raw_data\"]).build()\n", + "\n", + "changing_external_with_cache_results = changing_external_with_cache_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(changing_external_with_cache_results[\"amount_per_country\"].head())\n", + "print()\n", + "changing_external_with_cache_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `@cache` to specify caching behavior\n", + "Another way to specify the `RECOMPUTE` behavior is to use the `@cache` decorator." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_external_decorator_module\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@cache(behavior=\"recompute\")\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We build a new `Driver` with `changing_external_cache_decorator_module`, which includes the `@cache` decorator. Note that we don't specify anything in `.with_cache()`." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_external_decorator_dr = (\n", + " driver.Builder()\n", + " .with_modules(changing_external_decorator_module)\n", + " .with_cache()\n", + " .build()\n", + ")\n", + "\n", + "changing_external_decorator_results = changing_external_decorator_dr.execute(\n", + " [\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(changing_external_decorator_results[\"amount_per_country\"].head())\n", + "print()\n", + "changing_external_decorator_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that `raw_data` was re-executed. 
Then, `processed_data` and `amount_per_country` can be retrieved because they were produced just before by the `changing_external_with_cache_dr`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### When to use `@cache` vs. `.with_cache()`?\n", + "\n", + "Specifying the caching behavior via `.with_cache()` or `@cache` is entirely equivalent. There are benefits to either approach:\n", + "\n", + "- `@cache`: specify behavior at the dataflow-level. The behavior is tied to the node and will be picked up by all `Driver` loading the module. This can prevent errors or unexpected behaviors for users of that dataflow.\n", + "\n", + "- `.with_cache()`: specify behavior at the `Driver`-level. Gives the flexiblity to change the behavior without modifying the dataflow code and committing changes. You might be ok with `DEFAULT` during development, but want to ensure `RECOMPUTE` in production.\n", + "\n", + "Importantly, the behavior specified in `.with_cache(...)` overrides whatever is in `@cache` because it is closer to execution. For example, having `.with_cache(default=[\"raw_data\"])` `@cache(behavior=\"recompute\")` would force `DEFAULT` behavior.\n", + "\n", + "> ⛔ **Important**: Using the `@cache` decorator alone doesn't enable caching; adding `.with_cache()` to the `Builder` does. The decorator is only a mean to specify special behaviors for a node.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Force recompute all\n", + "By specifying `.with_cache(recompute=True)`, you are setting the behavior `RECOMPUTE` for all nodes. This forces recomputation, which is useful for producing a \"cache refresh\" with up-to-date values." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recompute_all_dr = (\n", + " driver.Builder()\n", + " .with_modules(changing_external_decorator_module)\n", + " .with_cache(recompute=True)\n", + " .build()\n", + ")\n", + "\n", + "recompute_all_results = 
recompute_all_dr.execute(\n", + " [\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(recompute_all_results[\"amount_per_country\"].head())\n", + "print()\n", + "recompute_all_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that all nodes were recomputed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting default behavior\n", + "\n", + "Once you enable caching using `.with_cache()`, it is a \"opt-out\" feature by default. This means all nodes are cached unless you set the `DISABLE` behavior via `@cache` or `.with_cache(disable=[...])`. This can become difficult to manage as the number of nodes increases. \n", + "\n", + "You can make it an \"opt-in\" feature by setting `default_behavior=\"disable\"` in `.with_cache()`. This way, you're using caching, but only for nodes explicitly specified in `@cache` or `.with_cache()`.\n", + "\n", + "Here, we build a `Driver` with the `changing_external_decorator_module`, where `raw_data` was set to have behavior `RECOMPUTE`, and set the default behavior to `DISABLE`." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "default_behavior_dr = (\n", + " driver.Builder()\n", + " .with_modules(changing_external_decorator_module)\n", + " .with_cache(default_behavior=\"disable\")\n", + " .build()\n", + ")\n", + "\n", + "default_behavior_results = default_behavior_dr.execute(\n", + " [\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(default_behavior_results[\"amount_per_country\"].head())\n", + "print()\n", + "default_behavior_dr.cache.view_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 25, + 
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "default_behavior_dr.cache.behaviors[default_behavior_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Materializers\n", + "\n", + "> NOTE. You can skip this section if you're not using materializers.\n", + "\n", + "`DataLoader` and `DataSaver` (collectively \"materializers\") are special Hamilton nodes that connect your dataflow to external data (files, databases, etc.). These constructs are safe to use with caching and are complementary.\n", + "\n", + "**Caching**\n", + "- writing and reading shorter-term data to be used with the dataflow\n", + "- strong connection between the code and the data\n", + "- automatically handle multiple versions of the same dataset\n", + "\n", + "**Materializers**\n", + "- robust mechanism to read/write data from many sources\n", + "- data isn't necessarily meant to be used with Hamilton (e.g., loading from a warehouse, outputting a report).\n", + "- typically outputs to a static destination; each write overwrites the previous stored dataset.\n", + "\n", + "The next cell uses `@dataloader` and `@datasaver` decorators. In the visualization, we see the added `raw_data.loader` and `saved_data` nodes." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module materializers_module -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import dataloader, datasaver\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", 
\"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@dataloader()\n", + "def raw_data() -> tuple[pd.DataFrame, dict]:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " data = pd.DataFrame(DATA)\n", + " metadata = {\"source\": \"notebook\", \"format\": \"json\"}\n", + " return data, metadata\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()\n", + "\n", + "@datasaver()\n", + "def saved_data(amount_per_country: pd.DataFrame) -> dict:\n", + " amount_per_country.to_parquet(\"./saved_data.parquet\")\n", + " metadata = {\"source\": \"notebook\", \"format\": \"parquet\"}\n", + " return metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we build a `Driver` as usual. " + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + 
"materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_dr = (\n", + " driver.Builder()\n", + " .with_modules(materializers_module)\n", + " .with_cache()\n", + " .build()\n", + ")\n", + "\n", + "materializers_results = materializers_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(materializers_results[\"amount_per_country\"].head())\n", + "print()\n", + "materializers_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We execute the dataflow a second time to show that loaders and savers are just like any other node; they can be cached and retrieved." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::result_store::get_result::hit\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_results = materializers_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(materializers_results[\"amount_per_country\"].head())\n", + "print()\n", + "materializers_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Usage patterns\n", + "\n", + "Here are a few common scenarios:\n", + "\n", + "**Loading 
data is expensive**: Your dataflow uses a `DataLoader` to get data from Snowflake. You want to load it once and cache it. When executing your dataflow, you want to use your cached copy to save query time, egress costs, etc.\n", + "- Use the `DEFAULT` caching behavior for loaders.\n", + "\n", + "**Only save new data**: You run the dataflow multiple times (maybe with different parameters or on a schedule) and only want to write to destination when the data changes.\n", + "- Use the `DEFAULT` caching behavior for savers.\n", + "\n", + "**Always read the latest data**: You want to use caching, but also ensure the dataflow always uses the latest data. This involves executing the `DataLoader` every time, get the data in-memory, version it, and then determine what needs to be executed (see **Changing external data**).\n", + "- Use the `RECOMPUTE` caching behavior for loaders.\n", + "\n", + "Use the parameters `default_loader_behavior` or `default_saver_behavior` of the `.with_cache()` clause to specify the behavior for all loaders or savers.\n", + "\n", + "> NOTE. The **Caching + materializers tutorial** notebook details how to achieve granular control over loader and saver behaviors." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_dr_2 = (\n", + " driver.Builder()\n", + " .with_modules(materializers_module)\n", + " .with_cache(\n", + " 
default_loader_behavior=\"recompute\",\n", + " default_saver_behavior=\"disable\"\n", + " )\n", + " .build()\n", + ")\n", + "\n", + "materializers_results_2 = materializers_dr_2.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(materializers_results_2[\"amount_per_country\"].head())\n", + "print()\n", + "materializers_dr_2.cache.view_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data.loader': ,\n", + " 'raw_data': ,\n", + " 'saved_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_dr_2.cache.behaviors[materializers_dr_2.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing the cache format\n", + "\n", + "By default, results are stored in ``pickle`` format. It's a convenient default but [comes with caveats](https://grantjenks.com/docs/diskcache/tutorial.html#caveats). You can use the `@cache` decorator to specify another file format for storing results.\n", + "\n", + "By default this includes:\n", + "\n", + "- `json`\n", + "- `parquet`\n", + "- `csv`\n", + "- `excel`\n", + "- `file`\n", + "- `feather`\n", + "- `orc`\n", + "\n", + "This feature uses `DataLoader` and `DataSaver` under the hood and supports all of the same formats (including your custom ones, as long as they take a `path` attribute).\n", + "\n", + "> This is an area of active development. Feel free to share suggestions and feedback!\n", + "\n", + "The next cell sets `processed_data` to be cached using the `parquet` format." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module cache_format_module\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "@cache(format=\"parquet\")\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.Series:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When executing the dataflow, we see `raw_data` recomputed because it's a dataloader. The result for `processed_data` will be retrieved, but it will be saved again as `.parquet` this time. " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "country\n", + "Canada 941.957\n", + "USA 1719.240\n", + "Name: amound_in_usd, dtype: float64\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "Series\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache_format_dr = driver.Builder().with_modules(cache_format_module).with_cache().build()\n", + "\n", + "cache_format_results = cache_format_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(cache_format_results[\"amount_per_country\"].head())\n", + "print()\n", + "cache_format_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, under the `./.hamilton_cache`, there will be two results of the same name, one with the `.parquet` extension and one without. 
The one without is actually a pickeld `DataLoader` to retrieve the `.parquet` file.\n", + "\n", + "You can access the path programmatically via the `result_store._path_from_data_version(...)` method." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_version = cache_format_dr.cache.data_versions[cache_format_dr.cache.last_run_id][\"processed_data\"]\n", + "parquet_path = cache_format_dr.cache.result_store._path_from_data_version(data_version).with_suffix(\".parquet\")\n", + "parquet_path.exists()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introspecting the cache\n", + "The `Driver.cache` stores information about all executions over its lifetime. Previous `run_id` are available through `Driver.cache.run_ids` and can be used in tandem without other utility functions:\n", + "\n", + "- Resolve the node caching behavior (e.g., \"recompute\")\n", + "- Access structured logs\n", + "- Visualize the cache execution\n", + "\n", + "Also, `Driver.cache.last_run_id` is a shortcut to the most recent execution." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache_format_dr.cache.resolve_behaviors(cache_format_dr.cache.last_run_id)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "processed_data::adapter::resolve_behavior\n", + "processed_data::adapter::set_cache_key\n", + "processed_data::adapter::get_cache_key::hit\n", + "processed_data::adapter::get_data_version::miss\n", + "processed_data::metadata_store::get_data_version::miss\n", + "processed_data::adapter::execute_node\n", + "processed_data::adapter::set_data_version\n", + "processed_data::metadata_store::set_data_version\n", + "processed_data::adapter::get_cache_key::hit\n", + "processed_data::adapter::get_data_version::hit\n", + "processed_data::result_store::set_result\n", + "processed_data::adapter::get_data_version::hit\n", + "processed_data::adapter::resolve_behavior\n" + ] + } + ], + "source": [ + "run_logs = cache_format_dr.cache.logs(cache_format_dr.cache.last_run_id, level=\"debug\")\n", + "for event in run_logs[\"processed_data\"]:\n", + " print(event)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "Series\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# for `.view_run()` passing no parameter is equivalent to the last `run_id`\n", + "cache_format_dr.cache.view_run(cache_format_dr.cache.last_run_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interactively explore runs\n", + "By using `ipywidgets` we can easily build a widget to iterate over `run_id` values and display cache information. Below, we create a `Driver` and execute it a few times to generate data then inspect it with a widget." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::result_store::get_result::hit\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "data": { + "text/plain": [ + "{'amount_per_country': Series([], Name: amound_in_usd, dtype: float64)}" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactive_dr = driver.Builder().with_modules(cache_format_module).with_cache().build()\n", + "\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-05\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-10\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-11\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-13\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following cell allows you to click-and-drag or use arrow-keys to navigate" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a9785e33191453bac0b952ce1f80ef3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(SelectionSlider(description='run_id', options=('101f1759-82c3-416b-875b-e184b765af3c', '…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display\n", + "from ipywidgets import SelectionSlider, interact\n", + "\n", + "\n", + "@interact(run_id=SelectionSlider(options=interactive_dr.cache.run_ids))\n", + "def iterate_over_runs(run_id):\n", + " display(interactive_dr.cache.data_versions[run_id])\n", + " 
display(interactive_dr.cache.view_run(run_id=run_id))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Managing storage\n", + "### Setting the cache `path`\n", + "\n", + "By default, metadata and results are stored under `./.hamilton_cache`, relative to the current directory at execution time. You can also manually set the directory via `.with_cache(path=...)` to isolate or centralize cache storage between dataflows or projects.\n", + "\n", + "Running the next cell will create the directory `./my_other_cache`." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "manual_path_dr = driver.Builder().with_modules(cache_format_module).with_cache(path=\"./my_other_cache\").build()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiating the `result_store` and `metadata_store`\n", + "If you need to store metadata and results in separate locations, you can do so by instantiating the `result_store` and `metadata_store` manually with their own configuration. In this case, setting `.with_cache(path=...)` would be ignored." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton.caching.stores.file import FileResultStore\n", + "from hamilton.caching.stores.sqlite import SQLiteMetadataStore\n", + "\n", + "result_store = FileResultStore(path=\"./results\")\n", + "metadata_store = SQLiteMetadataStore(path=\"./metadata\")\n", + "\n", + "manual_stores_dr = (\n", + " driver.Builder()\n", + " .with_modules(cache_format_module)\n", + " .with_cache(\n", + " result_store=result_store,\n", + " metadata_store=metadata_store,\n", + " )\n", + " .build()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Deleting data and recovering storage\n", + "As you use caching, you might be generating a lot of data that you don't need anymore. One straightforward solution is to delete the entire directory where metadata and results are stored. \n", + "\n", + "You can also programmatically call `.delete_all()` on the `result_store` and `metadata_store`, which should reclaim most storage. If you delete results, make sure to also delete metadata. The caching mechanism should figure it out, but it's safer to keep them in sync." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "manual_stores_dr.cache.metadata_store.delete_all()\n", + "manual_stores_dr.cache.result_store.delete_all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage patterns\n", + "\n", + "As demonstrated here, caching works great in a notebook environment.\n", + "\n", + "- In addition to iteration speed, caching allows you to restart your kernel or shut down your computer for the day without worry. When you come back, you will still be able to retrieve results from the cache.\n", + "\n", + "- A similar benefit is the ability to resume execution between environments. For example, you might be running Hamilton in a script, but when a bug happens, you can reload these values in a notebook and investigate.\n", + "\n", + "- Caching also works well with other adapters, like the `HamiltonTracker` that powers the Hamilton UI and the `MLFlowTracker` for experiment tracking.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🚧 INTERNALS\n", + "If you're curious, the following sections provide details about the caching internals. 
These APIs are not public and may change without notice." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Manually retrieve results\n", + "Using the `Driver.cache`, you can directly retrieve results from previous executions. The cache stores \"data versions\", which are keys for the `result_store`. \n", + "\n", + "Here, we get the `run_id` for the 4th execution (index 3) and the data version for `processed_data` before retrieving its value." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.23\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.67\n" + ] + } + ], + "source": [ + "run_id = interactive_dr.cache.run_ids[3]\n", + "data_version = interactive_dr.cache.data_versions[run_id][\"processed_data\"]\n", + "result = interactive_dr.cache.result_store.get(data_version)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Decoding the `cache_key`\n", + "\n", + "By now, you should have a better grasp of how Hamilton's caching determines when to execute a node. Internally, it creates a `cache_key` from the `code_version` of the node and the `data_version` of each dependency. The cache keys are stored on the `Driver.cache` and can be decoded for introspection and debugging.\n", + "\n", + "Here, we get the `run_id` for the 3rd execution (index 2) and the cache key for `amount_per_country`. We then use `decode_key()` to retrieve the `node_name`, `code_version`, and `dependencies_data_versions`." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'node_name': 'amount_per_country',\n", + " 'code_version': 'c2ccafa54280fbc969870b6baa445211277d7e8cfa98a0821836c175603ffda2',\n", + " 'dependencies_data_versions': {'processed_data': 'WgV5-4SfdKTfUY66x-msj_xXsKNPNTP2guRhfw=='}}" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from hamilton.caching.cache_key import decode_key\n", + "\n", + "run_id = interactive_dr.cache.run_ids[2]\n", + "cache_key = interactive_dr.cache.cache_keys[run_id][\"amount_per_country\"]\n", + "decode_key(cache_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Indeed, this matches the data version of `processed_data` for the 3rd execution." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'WgV5-4SfdKTfUY66x-msj_xXsKNPNTP2guRhfw=='" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactive_dr.cache.data_versions[run_id][\"processed_data\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Manually retrieve metadata\n", + "\n", + "In addition to the `result_store`, there is a `metadata_store` that contains the mapping between `cache_key` and `data_version` (cache keys are unique, but many can point to the same data).\n", + "\n", + "Using the knowledge from the previous section, we can use the cache key for `amount_per_country` to retrieve its `data_version` and result. It's also possible to decode its `cache_key` and get the `data_version` for its dependencies, making the node execution reproducible."
+ ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "country\n", + "Canada 526.9194\n", + "USA 1719.2400\n", + "Name: amound_in_usd, dtype: float64\n" + ] + } + ], + "source": [ + "run_id = interactive_dr.cache.run_ids[2]\n", + "cache_key = interactive_dr.cache.cache_keys[run_id][\"amount_per_country\"]\n", + "amount_data_version = interactive_dr.cache.metadata_store.get(cache_key)\n", + "amount_result = interactive_dr.cache.result_store.get(amount_data_version)\n", + "print(amount_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "processed_data\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.23\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.67\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.34\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n", + "\n" + ] + } + ], + "source": [ + "for dep_name, dependency_data_version in decode_key(cache_key)[\"dependencies_data_versions\"].items():\n", + " dep_result = interactive_dr.cache.result_store.get(dependency_data_version)\n", + " print(dep_name)\n", + " print(dep_result)\n", + " print()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/how-tos/index.rst b/docs/how-tos/index.rst index 565b84ac4..6eb8eb68e 100644 --- a/docs/how-tos/index.rst +++ b/docs/how-tos/index.rst @@ -10,12 +10,12 @@ directory. If there's an example you want but don't see, reach out or open an is .. toctree:: use-in-jupyter-notebook load-data + caching-tutorial use-for-feature-engineering ml-training llm-workflows run-data-quality-checks use-hamilton-for-lineage - cache-nodes scale-up microservice extensions-autoloading diff --git a/docs/index.md b/docs/index.md index 81855aba6..fbe8ae77a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -29,6 +29,7 @@ Slack `_ +- when hitting the base case of ``fingerprinting.hash_value()`` we return the constant ``UNHASHABLE_VALUE``. If the adapter receives this value, it will append a random UUID to it. This is to prevent collision between unhashable types. This ``data_version`` is no longer deterministic, but the value can still be retrieved or be part of another node's ``cache_key``. +- having ``@functools.singledispatch(object)`` allows to override the base case of ``hash_value()`` because it will catch all types. diff --git a/docs/reference/caching/data-versioning.rst b/docs/reference/caching/data-versioning.rst new file mode 100644 index 000000000..567c5eb6b --- /dev/null +++ b/docs/reference/caching/data-versioning.rst @@ -0,0 +1,6 @@ +================= +Data versioning +================= + +.. automodule:: hamilton.caching.fingerprinting + :members: diff --git a/docs/reference/caching/index.rst b/docs/reference/caching/index.rst new file mode 100644 index 000000000..dd0803d02 --- /dev/null +++ b/docs/reference/caching/index.rst @@ -0,0 +1,13 @@ +============== +Caching +============== + +Reference +--------- + +.. 
toctree:: + :maxdepth: 2 + + caching-logic + data-versioning + stores diff --git a/docs/reference/caching/stores.rst b/docs/reference/caching/stores.rst new file mode 100644 index 000000000..7e1a83d22 --- /dev/null +++ b/docs/reference/caching/stores.rst @@ -0,0 +1,21 @@ +========= +Stores +========= + +stores.base +----------- + +.. automodule:: hamilton.caching.stores.base + :members: + +stores.file +----------- + +.. automodule:: hamilton.caching.stores.file + :members: + +stores.sqlite +------------- + +.. automodule:: hamilton.caching.stores.sqlite + :members: diff --git a/examples/caching/README.md b/examples/caching/README.md new file mode 100644 index 000000000..4b53b2218 --- /dev/null +++ b/examples/caching/README.md @@ -0,0 +1,6 @@ +# Caching + +This directory contains tutorial notebooks for the Hamilton caching feature. + +- `tutorial.ipynb`: the main tutorial for caching +- `materializer_tutorial.ipynb`: tutorial on the interactions between `DataLoader/DataSaver` and caching. This is a more advanced tutorial for materializer users. You should complete the `tutorial.ipynb` first. diff --git a/examples/caching/materializer_tutorial.ipynb b/examples/caching/materializer_tutorial.ipynb new file mode 100644 index 000000000..9fd20f97f --- /dev/null +++ b/examples/caching/materializer_tutorial.ipynb @@ -0,0 +1,2842 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "# execute to install Python dependencies\n", + "# %%capture\n", + "# !pip install sf-hamilton[visualization] ipywidgets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Caching + materializers tutorial\n", + "\n", + "This notebook is a companion tutorial to the **Hamilton caching tutorial** notebook, which introduces caching more broadly.\n", + "\n", + "Its **Materializers** section teaches about different usage patterns for caching + materializers and introduces the `default_loader_behavior` and `default_saver_behavior` parameters. This notebook will show how to control loader and saver behaviors granularly.\n", + "\n", + "## Use cases\n", + "\n", + "As a reminder, here are some potential usage patterns\n", + "\n", + "**Loading data is expensive**: Your dataflow uses a `DataLoader` to get data from Snowflake. You want to load it once and cache it. When executing your dataflow, you want to use your cached copy to save query time, egress costs, etc.\n", + "- Use the `DEFAULT` caching behavior for loaders.\n", + "\n", + "**Only save new data**: You run the dataflow multiple times (maybe with different parameters or on a schedule) and only want to write to destination when the data changes.\n", + "- Use the `DEFAULT` caching behavior for savers.\n", + "\n", + "**Always read the latest data**: You want to use caching, but also ensure the dataflow always uses the latest data. This involves executing the `DataLoader` every time, get the data in-memory, version it, and then determine what needs to be executed (see [Changing external data](#changing-external-data)).\n", + "- Use the `RECOMPUTE` caching behavior for loaders.\n", + "\n", + "> NOTE. Caching + materializers is actively being improved so default behaviors and low-level APIs might change. This is a very powerful combo. If you have ideas, questions, or use cases, please reach out on Slack!" 
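As a rough sketch of how these use cases map onto the `Builder` API (the module name `my_dataflow_module` is a placeholder; the exact behaviors are explored cell by cell in the rest of this notebook):

```python
from hamilton import driver

import my_dataflow_module  # placeholder for a module defining raw_data(), processed_data(), ...

dr = (
    driver.Builder()
    .with_modules(my_dataflow_module)
    .with_cache(
        # "Loading data is expensive" / "Only save new data": nothing extra to add,
        # the DEFAULT behavior caches loaders and savers like any other node.
        #
        # "Always read the latest data": force the loader to re-execute every run.
        recompute=["raw_data"],
        # The main caching tutorial also covers default_loader_behavior /
        # default_saver_behavior to set a behavior for all loaders or savers at once.
    )
    .build()
)
```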
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set up\n", + "The next cell sets up the notebook by:\n", + "- loading the Hamilton notebook extension\n", + "- getting the caching logger\n", + "- removing any existing cache directory" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import shutil\n", + "\n", + "from hamilton import driver\n", + "\n", + "CACHE_DIR = \"./.materializer_and_caching_cache\"\n", + "\n", + "logger = logging.getLogger(\"hamilton.caching\")\n", + "logger.setLevel(logging.INFO)\n", + "logger.addHandler(logging.StreamHandler())\n", + "\n", + "shutil.rmtree(CACHE_DIR, ignore_errors=True)\n", + "\n", + "# load the notebook extension\n", + "%reload_ext hamilton.plugins.jupyter_magic" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TL;DR\n", + "\n", + "Before diving into the details, here are the high-level ideas.\n", + "\n", + "Materializers are available in several flavors:\n", + "- `@dataloader` and `@datasaver`: create a custom function to serve as `DataLoader` or `DataSaver`.\n", + "- `@load_from`: use a decorator to specify that an argument should be loaded from an external source\n", + "- `@save_to`: create a node that saves the output of the decorated node\n", + "- `from_` and `to`: equivalent to `@load_from` and `@save_to` but at the `Driver`-level\n", + "\n", + "When materializers and caching interact, it's important to realize the following:\n", + "- `@dataloader` and `@datasaver` are just like any other nodes and you can use `@cache` and `.with_cache()` as usual.\n", + "- `@load_from` and `@save_to` create nodes dynamically, so there's no loader/saver function to apply `@cache` to directly. Instead, you add `@cache` to the function that has the `@load_from`/`@save_to` decorator. Also, you need to specify the name of internal nodes in `.with_cache()`, which can be trickier.\n", + "- `from_` and `to` can't be decorated with `@cache` because they're defined at the `Driver`-level. Defining \"static\" materializers using `.with_materializers()` and `.with_cache()` is more intuitive. If you're using `Driver.materialize()` with \"dynamic\" materializers, you can still use `.with_cache()`. This can feel odd because `.with_cache()` will define behaviors for nodes that don't exist yet.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `@dataloader` and `@datasaver`\n", + "### Dataflow-level\n", + "\n", + "Let's rewrite the dataflow with the `@dataloader` and `@datasaver` decorators.\n", + "\n", + "- **DataLoader**: the function `raw_data()` now returns a `tuple` of `(result, metadata)`. The tuple type annotation needs to specify that `raw_data` returns a `pd.DataFrame` as the first element.\n", + "- **DataSaver**: the function `saved_data()` was added. It receives `amount_per_country()` and saves it to a parquet file. It must return a dictionary, which can contain metadata.\n", + "\n", + "Using the `@cache` decorator with `raw_data` or `saved_data` will apply the behavior to all associated materialization nodes.\n", + "\n", + "> NOTE. The `@cache` decorator can be above or below the `@dataloader` / `@datasaver` decorator."
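For reference, a stripped-down sketch of the loader/saver contract described above; the full dataflow version is in the next cell:

```python
import pandas as pd
from hamilton.function_modifiers import cache, dataloader, datasaver


@cache(behavior="recompute")  # can sit above or below @dataloader()
@dataloader()
def raw_data() -> tuple[pd.DataFrame, dict]:
    # a DataLoader function returns (result, metadata);
    # the annotation declares the result type as the first tuple element
    return pd.DataFrame({"a": [1, 2]}), {"source": "example"}


@datasaver()
def saved_data(raw_data: pd.DataFrame) -> dict:
    # a DataSaver function writes the data and returns a metadata dictionary
    raw_data.to_parquet("./example.parquet")
    return {"format": "parquet"}
```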
+ ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module dataloader_dataflow_module -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import dataloader, datasaver, cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@cache(behavior=\"recompute\")\n", + "@dataloader()\n", + "def raw_data() -> tuple[pd.DataFrame, dict]:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " data = pd.DataFrame(DATA)\n", + " metadata = {\"source\": \"notebook\", \"format\": \"json\"}\n", + " return data, metadata\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()\n", + "\n", + "@cache(behavior=\"recompute\")\n", + "@datasaver()\n", + "def saved_data(amount_per_country: pd.DataFrame) -> dict:\n", + " amount_per_country.to_parquet(\"./saved_data.parquet\")\n", + " metadata = {\"source\": \"notebook\", \"format\": \"parquet\"}\n", + " return metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The visualization now displays the \"materializer\" node for the data loader. When we execute the dataflow twice and see that both `raw_data` and the associated `raw_data.loader` are recomputed." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::adapter::execute_node\n", + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "saved_data::adapter::execute_node\n", + "saved_data::adapter::execute_node\n", + "raw_data.loader::adapter::execute_node\n", + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "saved_data::adapter::execute_node\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + 
"\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_dataflow_dr = (\n", + " driver.Builder()\n", + " .with_modules(dataloader_dataflow_module)\n", + " .with_cache(path=CACHE_DIR)\n", + " .build()\n", + ")\n", + "\n", + "dataloader_dataflow_results = dataloader_dataflow_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "dataloader_dataflow_results = dataloader_dataflow_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(dataloader_dataflow_results[\"amount_per_country\"].head())\n", + "print()\n", + "dataloader_dataflow_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can manually inspect the node caching behavior via the `Driver.cache`. The `RECOMPUTE` behavior applied to `raw_data` is also applied to the internal `raw_data.loader`." + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data.loader': ,\n", + " 'raw_data': ,\n", + " 'saved_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_dataflow_dr.cache.behaviors[dataloader_dataflow_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Driver-level\n", + "Now, let's specify the behavior at the `Driver`-level instead. The next cell contains the same module, but without the `@cache` decorator." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module dataloader_driver_module\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import dataloader, datasaver\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@dataloader()\n", + "def raw_data() -> tuple[pd.DataFrame, dict]:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " data = pd.DataFrame(DATA)\n", + " metadata = {\"source\": \"notebook\", \"format\": \"json\"}\n", + " return data, metadata\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()\n", + "\n", + "@datasaver()\n", + "def saved_data(amount_per_country: pd.DataFrame) -> dict:\n", + " amount_per_country.to_parquet(\"./saved_data.parquet\")\n", + " metadata = {\"source\": \"notebook\", \"format\": \"parquet\"}\n", + " return metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When building the `Driver`, we use `.with_cache(recompute=[\"raw_data\", \"saved_data\"])` to specify the nodes behavior." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::adapter::execute_node\n", + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_driver_dr = (\n", + " driver.Builder()\n", + " .with_modules(dataloader_driver_module)\n", + " .with_cache(\n", + " path=CACHE_DIR,\n", + " recompute=[\"raw_data\", \"saved_data\"],\n", + " )\n", + " .build()\n", + ")\n", + "\n", + "dataloader_driver_results = dataloader_driver_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(dataloader_driver_results[\"amount_per_country\"].head())\n", + "print()\n", + "dataloader_driver_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `RECOMPUTE` behavior applied to `raw_data` is also applied to the internal `raw_data.loader`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data.loader': ,\n", + " 'raw_data': ,\n", + " 'saved_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_driver_dr.cache.behaviors[dataloader_driver_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `@load_from` and `@save_to`\n", + "### Dataflow-level\n", + "\n", + "Using `@load_from` and `@save_to` respectively remove the need to have the `raw_data()` and `saved_data()` functions. Instead, the loader/saver nodes are created a runtime, meaning we can't directly decorate them with `@cache`.\n", + "\n", + "> At the time of release, the `@cache` decorator must be **under** the `@load_from` or `@save_to`. This quirk will be fixed because order shouldn't matter. \n", + "\n", + "The `@cache` decorator will be applied to `processed_data` and `amount_per_country`. By default, this will apply the behavior both to the loader node `raw_data`, but also `processed_data`. Similarly, `amount_per_country` and the generated `save.amount_per_country` will receive the behavior." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data\n", + "\n", + "processed_data.select_data.raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data->processed_data.select_data.raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "save.amount_per_country\n", + "\n", + "\n", + "save.amount_per_country\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country->save.amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module load_from_dataflow_module -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import load_from, save_to, cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", 
\"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@load_from.parquet(path=\"raw_data.parquet\", inject_=\"raw_data\")\n", + "@cache(behavior=\"recompute\")\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "@save_to.parquet(path=\"saved_data.parquet\")\n", + "@cache(behavior=\"recompute\")\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The visualization displays the internal nodes generated by `@load_from` and `@save_to`." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "processed_data.load_data.raw_data::adapter::execute_node\n", + "processed_data.load_data.raw_data::adapter::execute_node\n", + "processed_data.select_data.raw_data::adapter::execute_node\n", + "processed_data.select_data.raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "save.amount_per_country::adapter::execute_node\n", + "save.amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.957\n", + "Mexico 46.217\n", + "USA 2959.76\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data\n", + "\n", + "processed_data.select_data.raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "save.amount_per_country\n", + "\n", + "\n", + "save.amount_per_country\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data->processed_data.select_data.raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", 
+ "amount_per_country->save.amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "load_from_dataflow_dr = (\n", + " driver.Builder()\n", + " .with_modules(load_from_dataflow_module)\n", + " .with_cache(path=CACHE_DIR)\n", + " .build()\n", + ")\n", + "\n", + "load_from_dataflow_results = load_from_dataflow_dr.execute(\n", + " [\"amount_per_country\", \"save.amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(load_from_dataflow_results[\"amount_per_country\"].head())\n", + "print()\n", + "load_from_dataflow_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As expected, most nodes receive the `RECOMPUTE` behavior in this case. Note that both internal nodes `processed_data.load_data.raw_data` and `processed_data.select_data.raw_data` receive the behavior." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'save.amount_per_country': ,\n", + " 'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'processed_data.load_data.raw_data': ,\n", + " 'processed_data.select_data.raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "load_from_dataflow_dr.cache.behaviors[load_from_dataflow_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Granular control" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the previous cells, using `@cache` applied the behavior to all the nodes associated with the function decorated by `@load_from` or `@save_to`.\n", + "\n", + "To achieve granular control, we can use the `target_` parameter of the `@cache` decorator where you can specify the name of the generated nodes.\n", + "\n", + "For `@load_from`, we will want to target `processed_data.load_data.raw_data`. Generally, this node name has the form `f\"{main_node}.load_data.{loaded_node}\"`. In complex scenarios, you should also add `f\"{main_node}.select_data.{loaded_node}\"` to the `target_` parameter for extra safety.\n", + "\n", + "For `@save_to`, we will want to target `save.amount_per_country`. The generic node name is `f\"save.{main_node}\"`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data\n", + "\n", + "processed_data.select_data.raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data->processed_data.select_data.raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "save.amount_per_country\n", + "\n", + "\n", + "save.amount_per_country\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country->save.amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module load_from_granular_module -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import load_from, save_to, cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@load_from.parquet(path=\"raw_data.parquet\", inject_=\"raw_data\")\n", + "@cache(behavior=\"recompute\", target_=\"processed_data.load_data.raw_data\")\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "@save_to.parquet(path=\"saved_data.parquet\")\n", + "@cache(behavior=\"recompute\", target_=\"save.amount_per_country\")\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount 
in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "processed_data.load_data.raw_data::adapter::execute_node\n", + "processed_data.load_data.raw_data::adapter::execute_node\n", + "processed_data.select_data.raw_data::adapter::execute_node\n", + "processed_data.select_data.raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "save.amount_per_country::adapter::execute_node\n", + "save.amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.957\n", + "Mexico 46.217\n", + "USA 2959.76\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data\n", + "\n", + "processed_data.select_data.raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "save.amount_per_country\n", + "\n", + "\n", + "save.amount_per_country\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data->processed_data.select_data.raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country->save.amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "load_from_granular_dr = (\n", + " driver.Builder()\n", + " .with_modules(load_from_granular_module)\n", + " .with_cache(path=CACHE_DIR)\n", + " .build()\n", + ")\n", + "\n", + "load_from_granular_results = load_from_granular_dr.execute(\n", + " [\"amount_per_country\", \"save.amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(load_from_granular_results[\"amount_per_country\"].head())\n", + "print()\n", + "load_from_granular_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we see that the nodes decorated with `@load_from` and `@save_to` (`processed_data` and `amount_per_country`) don't receive the behavior specified in `@cache`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'save.amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'processed_data.select_data.raw_data': ,\n", + " 'processed_data.load_data.raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "load_from_granular_dr.cache.behaviors[load_from_granular_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Driver-level\n", + "\n", + "The next cell presents the same module as before, but without the `@cache` decorator." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data\n", + "\n", + "processed_data.select_data.raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data->processed_data.select_data.raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "save.amount_per_country\n", + "\n", + "\n", + "save.amount_per_country\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country->save.amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module load_from_driver_module -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import load_from, save_to\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@load_from.parquet(path=\"raw_data.parquet\", inject_=\"raw_data\")\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter 
out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "@save_to.parquet(path=\"saved_data.parquet\")\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the `.with_cache()` clause, we don't have to specify the loader's internal names; we can simply use `\"raw_data\"`. For the saver, we must use `\"save.amount_per_country\"` because this matches the name we need to pass to `Driver.execute()`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "processed_data.load_data.raw_data::adapter::execute_node\n", + "processed_data.load_data.raw_data::adapter::execute_node\n", + "processed_data.select_data.raw_data::adapter::execute_node\n", + "processed_data.select_data.raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "save.amount_per_country::adapter::execute_node\n", + "save.amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.957\n", + "Mexico 46.217\n", + "USA 2959.76\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data\n", + "\n", + "processed_data.select_data.raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data.select_data.raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "save.amount_per_country\n", + "\n", + "\n", + "save.amount_per_country\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "\n", + "\n", + "processed_data.load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "processed_data.load_data.raw_data->processed_data.select_data.raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country->save.amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + 
"text/plain": [ + "" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "load_from_driver_dr = (\n", + " driver.Builder()\n", + " .with_modules(load_from_driver_module)\n", + " .with_cache(\n", + " path=CACHE_DIR,\n", + " recompute=[\"raw_data\", \"save.amount_per_country\"]\n", + " )\n", + " .build()\n", + ")\n", + "\n", + "load_from_driver_results = load_from_driver_dr.execute(\n", + " [\"amount_per_country\", \"save.amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(load_from_driver_results[\"amount_per_country\"].head())\n", + "print()\n", + "load_from_driver_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The internal nodes associated with `raw_data` have the right behavior. It's generally easier to use than combining `@cache` and `@load_from`/`@save_to`." + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'save.amount_per_country': ,\n", + " 'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'processed_data.load_data.raw_data': ,\n", + " 'processed_data.select_data.raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "load_from_driver_dr.cache.behaviors[load_from_driver_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## `from_.` and `to.`\n", + "The constructs `from_` and `to` are ways of defining `DataLoader` and `DataSaver` objects at the `Driver`-level. Like the previous cells, there is no `raw_data()` or `saved_data()` nodes, but no `@load_from` & `@save_to` decorators either.\n", + "\n", + "Notice in the module visualization that `raw_data` now appears as an \"input\" and the saver node is absent." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "raw_data\n", + "DataFrame\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module from_module -d\n", + "import pandas as pd\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are two ways to use `from_` and `to`:\n", + "- via \"static\" materializers added to the `Driver` using `Builder.with_materializers()`\n", + "- via \"dynamic\" materializers passed to `Driver.materialize()` (similar to `Driver.execute()`)\n", + "\n", + "In both cases, it will work with the `.with_cache(recompute=...)` clause." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `.with_materializers()`\n", + "Here, we use `.with_materializers()` to add a parquet loader for `raw_data` and a parquet saver for `amount_per_country`. 
Note that in `to.parquet(id=...)`, the `id` will be the node name of the data saver.\n", + "\n", + "Then, we add to `.with_cache(recompute=[...])` the node names `raw_data` and `saved_data` (the saver `id`) \n", + "\n", + "We call them \"static\" materializers because they're attached to the `Driver`, can be visualized, and called directly via `.execute()`" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "load_data.raw_data\n", + "\n", + "\n", + "load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "load_data.raw_data->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from hamilton.io.materialization import from_, to\n", + "\n", + "static_from_dr = (\n", + " driver.Builder()\n", + " .with_modules(from_module)\n", + " .with_materializers(\n", + " from_.parquet(path=\"raw_data.parquet\", target=\"raw_data\"),\n", + " to.parquet(\n", + " id=\"saved_data\",\n", + " dependencies=[\"amount_per_country\"],\n", + " path=\"saved_data.parquet\",\n", + " )\n", + " )\n", + " .with_cache(\n", + " path=CACHE_DIR,\n", + " recompute=[\"raw_data\",\"saved_data\"],\n", + " default_loader_behavior=\"disable\",\n", + " )\n", + " .build()\n", + ")\n", + "static_from_dr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We execute the dataflow using `.execute()` and requesting the data saver's name `saved_data`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "load_data.raw_data::adapter::execute_node\n", + "load_data.raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "load_data.raw_data\n", + "\n", + "\n", + "load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "load_data.raw_data->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "static_from_results = static_from_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(static_from_results[\"amount_per_country\"].head())\n", + "print()\n", + "static_from_dr.cache.view_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data': ,\n", + " 'cutoff_date': ,\n", + " 'saved_data': ,\n", + " 'load_data.raw_data': }" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "static_from_dr.cache.behaviors[static_from_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `.materialize()`\n", + "Now, we build a `Driver` without the static materializers. Just like the dataflow definition, the module will show `raw_data` as an input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "raw_data\n", + "DataFrame\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dynamic_from_dr = (\n", + " driver.Builder()\n", + " .with_modules(from_module)\n", + " .with_cache(\n", + " path=CACHE_DIR,\n", + " recompute=[\"raw_data\", \"saved_data\"]\n", + " )\n", + " .build()\n", + ")\n", + "dynamic_from_dr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The method `Driver.materialize()` has a slightly different signature than `Driver.execute()`. - The first argument collects `DataLoader` and `DataSaver` objects\n", + "- `additional_vars` is equivalent to `final_vars` in `Driver.execute()`\n", + "- it returns a tuple of `(metadata, additional_vars_results)` " + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "load_data.raw_data::adapter::execute_node\n", + "load_data.raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "load_data.raw_data\n", + "\n", + "\n", + "load_data.raw_data\n", + "PandasParquetReader\n", + "\n", + "\n", + "\n", + "load_data.raw_data->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "PandasParquetWriter\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + 
"\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata, dynamic_from_results = dynamic_from_dr.materialize(\n", + " from_.parquet(path=\"raw_data.parquet\", target=\"raw_data\"),\n", + " to.parquet(\n", + " id=\"saved_data\",\n", + " dependencies=[\"amount_per_country\"],\n", + " path=\"saved_data.parquet\",\n", + " ),\n", + " additional_vars=[\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(dynamic_from_results[\"amount_per_country\"].head())\n", + "print()\n", + "dynamic_from_dr.cache.view_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data': ,\n", + " 'cutoff_date': ,\n", + " 'saved_data': ,\n", + " 'load_data.raw_data': }" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dynamic_from_dr.cache.behaviors[dynamic_from_dr.cache.last_run_id]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/caching/raw_data.parquet b/examples/caching/raw_data.parquet new file mode 100644 index 0000000000000000000000000000000000000000..6e844e8bfc8160c620e81df5235595dd166708cc GIT binary patch literal 4294 zcmcgwO>84c6>d8jC*#e8XympWDS%}$)}T#7ciWzgM>~?sPW-oycjDRjkD#Id61USo zUc1{r83io}Sdc~_aab*g16t;=i_nM*5+@GBg##xpNStQjfCL=4ATE30Rdu)1o_G$c z4Ycg)s`p;K_r0HbmH0lYvGgMSBR~CJe}xXw)J=-|zVBa$%M?Xb9_}0~>d!+t;VMBTJ5X zP`viKAI6Yx@!A^q*6@9p$Db^0&@1aHDzLU535bfR$OGCR4uzjAgnzmaf9j7fzf9fw zCWL~b{S*vq@S-SKf8wLR?W@o?Noc;23<8y7xa1a--eMEnDy;CpZm&{{Od3yj@bDTQ z?!d5<^6p@gb36-P=!;C}OcD%vyo=@Jcd>WR0~g*cO1vo(-`-mYf8dY*Jj%T2Z)Dk37+H4r zgsIzesI&gN*?k_|po4C9dCjnNlSJUp{_t~u{Mq&RJIfoN!p7W)l<7)^UM2~@m018B zhYwHyUO8UpCEgc&-7aEG%J%s)MFl4#^u|oNF2J5hz-EKo%BEeX{%e_Uu2S)T-nhu( z5}X(~LKEV~Ghg`kzWC!C@h2-dkuzH6V%?$dx>z?NHr8F=QYMFX@nXMO#jk9wG9Th@ zWC3`W@%?h`0^hP|_4~3e+5msy3;)U&|KsQ4pR76n9Ub$~t*a?25*UabS>x#-MX%iQ zyGQaCRoMW$ghcAtB?PgIomiXe#>W-FN~V|`VmP$Ev3_)=NB^o%TKs=rxmQ(H@hi!G z_??w-D83wi@0Acmt#3vGJzkgiLF7MyM=?d0NuJlry| z&1g*LH5mxHrK+$Rn0()aB}3*$52ma=m&yG6YrwD-J|x;@NAIkfV?|b_hqE*r*F$sC z!}em94)>E}L$I{oWXdpvt}L1unkJjj2TUHsAI0Eg7?Pr&;Jt@$<+pZX=N>S#C}Q?? 
zgmugy@z6gL8XT`>EBJX8?IDB?;ESI(5K5x# z3CbgMf^sK>md(!aN0ht=dM+IoliF(U>%MvHQf`mnkpCSXV^|7F;E08 zZ?+%6-a?8M3Zp?OSJwDOebCOQtyY6ov7DgS2SN_?=vlx8w%IfzH?xFS((4xAs1DnO zLj_|2SyL>k$5K%p5xu$cu%-3XR^q_sH|k8QQR_&BJ?s<2qi(gh^eOnR6W`VPVNGF@ zC7qqcei5fBF7T@-6;xH6l;+}of$!v^>cz$in%Wh(nj33YXl7H*L@gzPKM-F<;L2~e z8&$^61*=MhJmj<1nZ~8a3`#{$9CKyNcl5P;?Z)`D$+X04Dbx#dw?Sa7MnQC`# zj!E4fLmfhmdviCG8YxE9PhWb|cRSiIQ_(#=REw2OKwC zpK#V-ULH#J8Q|K@@)_SqTjVTt_t;7oS^G*kP5Q(%w{2C_t382B5*_DRBnF4y^a`xDh+5 z$&?5otSURrIVxcJS5ILe_RyF}5j-Ay1x$rLMQVviK;O+F1a#bg@l!$^wvXHO!6E6~ zWYhF2aPSv!PdX$!CRiTqK`_9d7Ibk04V?6K_x^a9|27^@WuS+ixm!GV?@^IxkMF;2 j$>x^W>ur%sF6vxjQTuvZKZ5@lydO$ffu9!O7mWV`=LMt; literal 0 HcmV?d00001 diff --git a/examples/caching/requirements.txt b/examples/caching/requirements.txt new file mode 100644 index 000000000..e90acb2ce --- /dev/null +++ b/examples/caching/requirements.txt @@ -0,0 +1,3 @@ +ipywidgets +pandas +sf-hamilton[visualization] diff --git a/examples/caching/tutorial.ipynb b/examples/caching/tutorial.ipynb new file mode 100644 index 000000000..54aefa9a7 --- /dev/null +++ b/examples/caching/tutorial.ipynb @@ -0,0 +1,4054 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# execute to install Python dependencies\n", + "# %%capture\n", + "# !pip install sf-hamilton[visualization] ipywidgets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Hamilton caching tutorial\n", + "In Hamilton, **caching** broadly refers to \"reusing results from previous executions to skip redundant computation\". If you change code or pass new data, it will automatically determine which results can be reused and which nodes need to be re-executed. This improves execution speed and reduces resource usage (computation, API credits, etc.).\n", + "\n", + "## Table of contents\n", + "- [Basics](#basics)\n", + " - [Understanding the `cache_key`](#understanding-the-cache_key)\n", + "- [Adding a node](#adding-a-node)\n", + "- [Changing inputs](#changing-inputs)\n", + "- [Changing code](#changing-code)\n", + "- [Changing external data](#changing-external-data)\n", + " - [Idempotency](#idempotency)\n", + " - [`.with_cache()` to specify caching behavior](#with_cache-to-specify-caching-behavior)\n", + " - [`@cache` to specify caching behavior](#cache-to-specify-caching-behavior)\n", + " - [When to use `@cache` vs `.with_cache()`](#when-to-use-cache-vs-with_cache)\n", + "- [Force recompute all](#force-recompute-all)\n", + "- [Setting default behavior](#setting-default-behavior)\n", + "- [Materializers](#materializers)\n", + " - [Usage patterns](#usage-patterns)\n", + "- [Changing the cache format](#changing-the-cache-format)\n", + "- [Introspecting the cache](#introspecting-the-cache)\n", + "- [Managing storage](#managing-storage)\n", + " - [Setting the cache path](#setting-the-cache-path)\n", + " - [Instantiating the result_store and metadata_store](#instantiating-the-result_store-and-metadata_store)\n", + " - [Deleting data and recovering storage](#deleting-data-and-recovering-storage)\n", + "- [Usage patterns](#usage-patterns)\n", + "- 🚧 INTERNALS\n", + " - [Manually retrieve results](#manually-retrieve-results)\n", + " - [Decoding the cache_key](#decoding-the-cache_key)\n", + " - [Manually retrieve metadata](#manually-retrieve-metadata)\n", + "\n", + "\n", + "> NOTE. This notebook is on the longer side. We highly suggest using the navigation bar to help." 
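Before diving in, here is a minimal sketch of what enabling caching looks like. This is an editorial illustration, not a cell from the original tutorial: `my_dataflow` and `some_output` are hypothetical placeholders, and every API used here (`with_modules`, `with_cache`, `execute`) is covered in detail in the sections below.

```python
from hamilton import driver

import my_dataflow  # hypothetical module containing your Hamilton functions

# .with_cache() enables caching for all nodes; with the default settings,
# results and metadata are stored under ./.hamilton_cache next to where the Driver runs.
dr = driver.Builder().with_modules(my_dataflow).with_cache().build()

results = dr.execute(["some_output"])  # first run: nodes are executed and results cached
results = dr.execute(["some_output"])  # second run: results are retrieved from the cache
```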
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Throughout this tutorial, we'll be using the Hamilton notebook extension to define dataflows directly in the notebook ([see tutorial](https://github.com/DAGWorks-Inc/hamilton/blob/main/examples/jupyter_notebook_magic/example.ipynb)).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton import driver\n", + "\n", + "# load the notebook extension\n", + "%reload_ext hamilton.plugins.jupyter_magic" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We import the `logging` module and get the logger from `hamilton.caching`. With the level set to ``INFO``, we'll see ``GET_RESULT`` and ``EXECUTE_NODE`` cache events as they happen." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "logger = logging.getLogger(\"hamilton.caching\")\n", + "logger.setLevel(logging.INFO)\n", + "logger.addHandler(logging.StreamHandler())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The next cell deletes the cached data to ensure this notebook can be run from top to bottom without any issues." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "\n", + "shutil.rmtree(\"./.hamilton_cache\", ignore_errors=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basics\n", + "\n", + "Throughout this notebook, we'll use the same simple dataflow that processes transactions in various locations and currencies.\n", + "\n", + "We use the cell magic `%%cell_to_module` from the Hamilton notebook extension. It will convert the content of the cell into a Python module that can be loaded by Hamilton. The `--display` flag allows to visualize the dataflow." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module basics_module --display\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.73\n", + " return df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then, we build the ``Driver`` with caching enabled and execute the dataflow." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n" + ] + } + ], + "source": [ + "basics_dr = driver.Builder().with_modules(basics_module).with_cache().build()\n", + "\n", + "basics_results_1 = basics_dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(basics_results_1[\"processed_data\"].head())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can view what values were retrieved from the cache using `dr.cache.view_run()`. Since this was the first execution, nothing is retrieved." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "basics_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On the second execution, `processed_data` is retrieved from cache as reported in the logs and highlighted in the visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "basics_results_2 = basics_dr.execute([\"processed_data\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(basics_results_2[\"processed_data\"].head())\n", + "print()\n", + "basics_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Understanding the `cache_key`\n", + "\n", + "The Hamilton cache stores results using a `cache_key`. 
It is composed of the node's name (`node_name`), the code that defines it (`code_version`), and its data inputs (`data_version` of its dependencies).\n", + "\n", + "For example, the cache keys for the previous cells are:\n", + "\n", + "```json\n", + "{\n", + " \"node_name\": \"raw_data\",\n", + " \"code_version\": \"9d727859b9fd883247c3379d4d25a35af4a56df9d9fde20c75c6375dde631c68\",\n", + " \"dependencies_data_versions\": {} // it has no dependencies\n", + "}\n", + "{\n", + " \"node_name\": \"processed_data\",\n", + " \"code_version\": \"c9e3377d6c5044944bd89eeb7073c730ee8707627c39906b4156c6411f056f00\",\n", + " \"dependencies_data_versions\": {\n", + " \"cutoff_date\": \"WkGjJythLWYAIj2Qr8T_ug==\", // input value\n", + " \"raw_data\": \"t-BDcMLikFSNdn4piUKy1mBcKPoEsnsYjUNzWg==\" // raw_data's result\n", + " }\n", + "}\n", + "```\n", + "\n", + "Results could be successfully retrieved because nodes in the first execution and second execution shared the same `cache_key`.\n", + "\n", + "The `cache_key` objects are internal and you won't have to interact with them directly. However, keep that concept in mind throughout this tutorial. Towards the end, we show how to manually handle the `cache_key` for debugging." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding a node\n", + "\n", + "Let's say you're iteratively developing your dataflow and you add a new node. Here, we copy the previous module into a new module named `adding_node_module` and define the node `amount_per_country`.\n", + "\n", + "> In practice, you would edit the cell directly, but this makes the notebook easier to read and maintain" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module adding_node_module --display\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.73\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We build a new `Driver` with `adding_node_module` and execute the dataflow. You'll notice that `raw_data` and `processed_data` are retrieved and only `amount_per_country` is executed." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 968.491\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adding_node_dr = driver.Builder().with_modules(adding_node_module).with_cache().build()\n", + "\n", + "adding_node_results = adding_node_dr.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(adding_node_results[\"amount_per_country\"].head())\n", + "print()\n", + "adding_node_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Even though this is the first execution of `adding_node_dr` and the module `adding_node_module`, the cache contains results for `raw_data` and `processed_data`. We're able to retrieve values because they have the same cache keys (code version and dependencies data versions).\n", + "\n", + "This means you can reuse cached results across dataflows. This is particularly useful with training and inference machine learning pipelines." 
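As an illustrative sketch of that pattern (the module names here are hypothetical, not part of the tutorial): two `Driver` objects built from different modules can point at the same cache location, so nodes with identical code and inputs are only computed once across both dataflows.

```python
from hamilton import driver

# Hypothetical modules that share the same feature-engineering functions.
import training_pipeline
import inference_pipeline

CACHE_PATH = "./.hamilton_cache"  # shared on-disk cache location

training_dr = driver.Builder().with_modules(training_pipeline).with_cache(path=CACHE_PATH).build()
inference_dr = driver.Builder().with_modules(inference_pipeline).with_cache(path=CACHE_PATH).build()

# Feature nodes computed by training_dr are cache hits for inference_dr,
# as long as their code version and dependency data versions match.
```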
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing inputs\n", + "\n", + "We reuse the same dataflow `adding_node_module`, but change the input `cutoff_date` from\n", + "`\"2024-09-01\"` to `\"2024-09-11\"`. \n", + "\n", + "\n", + "This new input forces `processed_data` to be re-executed. This produces a new result for `processed_data`, which cascades and also forced `amount_per_country` to be re-executed." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "USA 729.9\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_inputs_dr = driver.Builder().with_modules(adding_node_module).with_cache().build()\n", + "\n", + "changing_inputs_results_1 = changing_inputs_dr.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-11\"}\n", + ")\n", + "print()\n", + "print(changing_inputs_results_1[\"amount_per_country\"].head())\n", + "print()\n", + "changing_inputs_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we execute with the `cutoff_date` value `\"2024-09-05\"`, which forces `processed_data` to be executed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 968.491\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_inputs_results_2 = changing_inputs_dr.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-05\"}\n", + ")\n", + "print()\n", + "print(changing_inputs_results_2[\"amount_per_country\"].head())\n", + "print()\n", + "changing_inputs_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the cache could still retrieve `amount_per_country`. This is because `processed_data` return a value that had been cached previously (in the `Adding a node` section).\n", + "\n", + "In concrete terms, filtering rows by the date `\"2024-09-05\"` or `\"2024-09-01\"` includes the same rows and produces the same dataframe." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n", + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 541.7622\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 426.7288\n" + ] + } + ], + "source": [ + "print(adding_node_results[\"processed_data\"])\n", + "print()\n", + "print(changing_inputs_results_2[\"processed_data\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing code\n", + "As you develop your dataflow, you will need to edit upstream nodes. 
Caching will automatically detect code changes and determine which node needs to be re-executed. In `processed_data()`, we'll change the conversation rate from `0.73` to `0.71`.\n", + "\n", + "> NOTE. changes to docstrings and comments `#` are ignored when versioning a node." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_code_module\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 # <- VALUE CHANGED FROM module_2\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to execute `processed_data` because the code change created a new `cache_key` and led to a cache miss. 
Then, `processed_data` returns a previously unseen value, forcing `amount_per_country` to also be re-executed" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 941.957\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_code_dr_1 = driver.Builder().with_modules(changing_code_module).with_cache().build()\n", + "\n", + "changing_code_results_1 = changing_code_dr_1.execute(\n", + " [\"processed_data\", \"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(changing_code_results_1[\"amount_per_country\"].head())\n", + "print()\n", + "changing_code_dr_1.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We make another code change to `processed_data` to accomodate currency conversion for Brazil and Mexico." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_code_module_2\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18 # <- LINE ADDED\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05 # <- LINE ADDED\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, the code change forces `processed_data` to be executed." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 941.957\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_code_dr_2 = driver.Builder().with_modules(changing_code_module_2).with_cache().build()\n", + "\n", + "changing_code_results_2 = changing_code_dr_2.execute([\"processed_data\",\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(changing_code_results_2[\"amount_per_country\"].head())\n", + "print()\n", + "changing_code_dr_2.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, `amount_per_country` can be retrieved because `processed_data` returned a previously seen value.\n", + "\n", + "In concrete terms, adding code to process currency from Brazil and Mexico didn't change the `processed_data` result because it only includes data from the USA and Canada.\n", + "\n", + "> NOTE. This is similar to what happened at the end of the section **Changing inputs**." 
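The next cell prints both results so you can compare them by eye; a more direct check (an editorial addition, not part of the original outputs) is to compare them programmatically. Identical data means an identical data version for `processed_data`, hence the downstream cache hit for `amount_per_country`.

```python
# Both code versions produced the same dataframe for this dataset, so
# amount_per_country saw the same dependency data versions and was retrieved from cache.
assert changing_code_results_1["processed_data"].equals(changing_code_results_2["processed_data"])
```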
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n", + "\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.2300\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.6700\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.3400\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n", + "4 Vancouver 2024-09-09 584.56 Canada CAD 415.0376\n" + ] + } + ], + "source": [ + "print(changing_code_results_1[\"processed_data\"])\n", + "print()\n", + "print(changing_code_results_2[\"processed_data\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing external data\n", + "\n", + "Hamilton's caching mechanism uses the node's `code_version` and its dependencies `data_version` to determine if the node needs to be executed or the result can be retrieved from cache. By default, it assumes [idempotency](https://www.astronomer.io/docs/learn/dag-best-practices#review-idempotency) of operations.\n", + "\n", + "This section covers how to handle node with external effects, such as reading or writing external data.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Idempotency\n", + "\n", + "To illustrate idempotency, let's use this minimal dataflow which has a single node that returns the current date and time:\n", + "\n", + "```python\n", + "import datetime\n", + "\n", + "def current_datetime() -> datetime.datetime:\n", + " return datetime.datetime.now()\n", + "```\n", + "\n", + "The first execution will execute the node and store the resulting date and time. On the second execution, the cache will read the stored result instead of re-executing. Why? Because the `code_version` is the same and the dependencies `data_version` (it has no dependencies) haven't changed.\n", + "\n", + "A similar situation occurs when reading from external data, as shown here:\n", + "\n", + "```python\n", + "import pandas as pd\n", + "\n", + "def dataset(file_path: str) -> pd.DataFrame:\n", + " return pd.read_csv(file_path)\n", + "```\n", + "\n", + "Here, the code of `dataset()` and the value for `file_path` can stay the same, but the file itself could be updated (e.g., new rows added).\n", + "\n", + "The next sections show how to always re-execute a node and ensure the latest data is used. The `DATA` constant is modified with transactions in Brazil and Mexico to simulate `raw_data` loading a new dataset." 
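As a preview of those sections, here is a minimal sketch of one remedy: marking the loader node for recomputation with the `@cache` decorator introduced below. The `file_path` input is the same hypothetical example as above.

```python
import pandas as pd

from hamilton.function_modifiers import cache


@cache(behavior="recompute")  # always re-execute this node instead of reusing a cached result
def dataset(file_path: str) -> pd.DataFrame:
    return pd.read_csv(file_path)

# Downstream nodes remain cache hits whenever the file's content,
# and therefore the data version of `dataset`, is unchanged.
```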
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_external_module\n", + "import pandas as pd\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "At execution, we see `raw_data` being retrieved along with all downstream nodes. Also, we note that the printed results don't include Brazil nor Mexico." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Canada 941.957\n", + "USA 1719.240\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_external_dr = driver.Builder().with_modules(changing_external_module).with_cache().build()\n", + "\n", + "changing_external_results = changing_external_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(changing_external_results[\"amount_per_country\"].head())\n", + "print()\n", + "changing_external_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `.with_cache()` to specify caching behavior\n", + "Here, we build a new `Driver` with the same `changing_external_module`, but we specify in `.with_cache()` to always recompute `raw_data`. \n", + "\n", + "The visualization shows that `raw_data` was executed, and because of the new data, all downstream nodes also need to be executed. The results now include Brazil and Mexico." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_external_with_cache_dr = driver.Builder().with_modules(changing_external_module).with_cache(recompute=[\"raw_data\"]).build()\n", + "\n", + "changing_external_with_cache_results = changing_external_with_cache_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(changing_external_with_cache_results[\"amount_per_country\"].head())\n", + "print()\n", + "changing_external_with_cache_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `@cache` to specify caching behavior\n", + "Another way to specify the `RECOMPUTE` behavior is to use the `@cache` decorator." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module changing_external_decorator_module\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@cache(behavior=\"recompute\")\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We build a new `Driver` with `changing_external_cache_decorator_module`, which includes the `@cache` decorator. Note that we don't specify anything in `.with_cache()`." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "changing_external_decorator_dr = (\n", + " driver.Builder()\n", + " .with_modules(changing_external_decorator_module)\n", + " .with_cache()\n", + " .build()\n", + ")\n", + "\n", + "changing_external_decorator_results = changing_external_decorator_dr.execute(\n", + " [\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(changing_external_decorator_results[\"amount_per_country\"].head())\n", + "print()\n", + "changing_external_decorator_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that `raw_data` was re-executed. 
Then, `processed_data` and `amount_per_country` can be retrieved because they were produced just before by the `changing_external_with_cache_dr`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### When to use `@cache` vs. `.with_cache()`?\n", + "\n", + "Specifying the caching behavior via `.with_cache()` or `@cache` is entirely equivalent. There are benefits to either approach:\n", + "\n", + "- `@cache`: specify behavior at the dataflow-level. The behavior is tied to the node and will be picked up by all `Driver` loading the module. This can prevent errors or unexpected behaviors for users of that dataflow.\n", + "\n", + "- `.with_cache()`: specify behavior at the `Driver`-level. Gives the flexiblity to change the behavior without modifying the dataflow code and committing changes. You might be ok with `DEFAULT` during development, but want to ensure `RECOMPUTE` in production.\n", + "\n", + "Importantly, the behavior specified in `.with_cache(...)` overrides whatever is in `@cache` because it is closer to execution. For example, having `.with_cache(default=[\"raw_data\"])` `@cache(behavior=\"recompute\")` would force `DEFAULT` behavior.\n", + "\n", + "> ⛔ **Important**: Using the `@cache` decorator alone doesn't enable caching; adding `.with_cache()` to the `Builder` does. The decorator is only a mean to specify special behaviors for a node.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Force recompute all\n", + "By specifying `.with_cache(recompute=True)`, you are setting the behavior `RECOMPUTE` for all nodes. This forces recomputation, which is useful for producing a \"cache refresh\" with up-to-date values." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "recompute_all_dr = (\n", + " driver.Builder()\n", + " .with_modules(changing_external_decorator_module)\n", + " .with_cache(recompute=True)\n", + " .build()\n", + ")\n", + "\n", + "recompute_all_results = 
recompute_all_dr.execute(\n", + " [\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(recompute_all_results[\"amount_per_country\"].head())\n", + "print()\n", + "recompute_all_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We see that all nodes were recomputed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting default behavior\n", + "\n", + "Once you enable caching using `.with_cache()`, it is a \"opt-out\" feature by default. This means all nodes are cached unless you set the `DISABLE` behavior via `@cache` or `.with_cache(disable=[...])`. This can become difficult to manage as the number of nodes increases. \n", + "\n", + "You can make it an \"opt-in\" feature by setting `default_behavior=\"disable\"` in `.with_cache()`. This way, you're using caching, but only for nodes explicitly specified in `@cache` or `.with_cache()`.\n", + "\n", + "Here, we build a `Driver` with the `changing_external_decorator_module`, where `raw_data` was set to have behavior `RECOMPUTE`, and set the default behavior to `DISABLE`." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::adapter::execute_node\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "default_behavior_dr = (\n", + " driver.Builder()\n", + " .with_modules(changing_external_decorator_module)\n", + " .with_cache(default_behavior=\"disable\")\n", + " .build()\n", + ")\n", + "\n", + "default_behavior_results = default_behavior_dr.execute(\n", + " [\"amount_per_country\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(default_behavior_results[\"amount_per_country\"].head())\n", + "print()\n", + "default_behavior_dr.cache.view_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 25, + 
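If you prefer to opt nodes in from the `Driver` side rather than with the decorator, a sketch like the following (again assuming the module defined above) disables caching by default and re-enables it only for `processed_data`. Note that `raw_data` keeps the `RECOMPUTE` behavior it was given via `@cache`.

```python
opt_in_dr = (
    driver.Builder()
    .with_modules(changing_external_decorator_module)
    .with_cache(default_behavior="disable", default=["processed_data"])
    .build()
)

opt_in_dr.execute(["amount_per_country"], inputs={"cutoff_date": "2024-09-01"})

# nodes without an explicit behavior should resolve to DISABLE,
# `processed_data` to DEFAULT, and `raw_data` to RECOMPUTE (from its decorator)
print(opt_in_dr.cache.behaviors[opt_in_dr.cache.last_run_id])
```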
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "default_behavior_dr.cache.behaviors[default_behavior_dr.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Materializers\n", + "\n", + "> NOTE. You can skip this section if you're not using materializers.\n", + "\n", + "`DataLoader` and `DataSaver` (collectively \"materializers\") are special Hamilton nodes that connect your dataflow to external data (files, databases, etc.). These constructs are safe to use with caching and are complementary.\n", + "\n", + "**Caching**\n", + "- writing and reading shorter-term data to be used with the dataflow\n", + "- strong connection between the code and the data\n", + "- automatically handle multiple versions of the same dataset\n", + "\n", + "**Materializers**\n", + "- robust mechanism to read/write data from many sources\n", + "- data isn't necessarily meant to be used with Hamilton (e.g., loading from a warehouse, outputting a report).\n", + "- typically outputs to a static destination; each write overwrites the previous stored dataset.\n", + "\n", + "The next cell uses `@dataloader` and `@datasaver` decorators. In the visualization, we see the added `raw_data.loader` and `saved_data` nodes." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module materializers_module -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import dataloader, datasaver\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", 
\"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "@dataloader()\n", + "def raw_data() -> tuple[pd.DataFrame, dict]:\n", + " \"\"\"Loading raw data. This simulates loading from a file, database, or external service.\"\"\"\n", + " data = pd.DataFrame(DATA)\n", + " metadata = {\"source\": \"notebook\", \"format\": \"json\"}\n", + " return data, metadata\n", + "\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.DataFrame:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum().to_frame()\n", + "\n", + "@datasaver()\n", + "def saved_data(amount_per_country: pd.DataFrame) -> dict:\n", + " amount_per_country.to_parquet(\"./saved_data.parquet\")\n", + " metadata = {\"source\": \"notebook\", \"format\": \"parquet\"}\n", + " return metadata" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we build a `Driver` as usual. " + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + 
"materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_dr = (\n", + " driver.Builder()\n", + " .with_modules(materializers_module)\n", + " .with_cache()\n", + " .build()\n", + ")\n", + "\n", + "materializers_results = materializers_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(materializers_results[\"amount_per_country\"].head())\n", + "print()\n", + "materializers_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We execute the dataflow a second time to show that loaders and savers are just like any other node; they can be cached and retrieved." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::result_store::get_result::hit\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::result_store::get_result::hit\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_results = materializers_dr.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(materializers_results[\"amount_per_country\"].head())\n", + "print()\n", + "materializers_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Usage patterns\n", + "\n", + "Here are a few common scenarios:\n", + "\n", + "**Loading 
data is expensive**: Your dataflow uses a `DataLoader` to get data from Snowflake. You want to load it once and cache it. When executing your dataflow, you want to use your cached copy to save query time, egress costs, etc.\n", + "- Use the `DEFAULT` caching behavior for loaders.\n", + "\n", + "**Only save new data**: You run the dataflow multiple times (maybe with different parameters or on a schedule) and only want to write to destination when the data changes.\n", + "- Use the `DEFAULT` caching behavior for savers.\n", + "\n", + "**Always read the latest data**: You want to use caching, but also ensure the dataflow always uses the latest data. This involves executing the `DataLoader` every time, get the data in-memory, version it, and then determine what needs to be executed (see [Changing external data](#changing-external-data)).\n", + "- Use the `RECOMPUTE` caching behavior for loaders.\n", + "\n", + "Use the parameters `default_loader_behavior` or `default_saver_behavior` of the `.with_cache()` clause to specify the behavior for all loaders or savers.\n", + "\n", + "> NOTE. The **Caching + materializers tutorial** notebook details how to achieve granular control over loader and saver behaviors." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data.loader::adapter::execute_node\n", + "raw_data::adapter::execute_node\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "saved_data::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " amound_in_usd\n", + "country \n", + "Brazil 77.9004\n", + "Canada 941.9570\n", + "Mexico 46.2170\n", + "USA 2959.7600\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "raw_data.loader\n", + "\n", + "\n", + "raw_data.loader\n", + "raw_data()\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data.loader->raw_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "saved_data\n", + "\n", + "\n", + "saved_data\n", + "saved_data()\n", + "\n", + "\n", + "\n", + "amount_per_country->saved_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "materializer\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_dr_2 = (\n", + " driver.Builder()\n", + " .with_modules(materializers_module)\n", + " 
.with_cache(\n", + " default_loader_behavior=\"recompute\",\n", + " default_saver_behavior=\"disable\"\n", + " )\n", + " .build()\n", + ")\n", + "\n", + "materializers_results_2 = materializers_dr_2.execute(\n", + " [\"amount_per_country\", \"saved_data\"],\n", + " inputs={\"cutoff_date\": \"2024-09-01\"}\n", + ")\n", + "print()\n", + "print(materializers_results_2[\"amount_per_country\"].head())\n", + "print()\n", + "materializers_dr_2.cache.view_run()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data.loader': ,\n", + " 'raw_data': ,\n", + " 'saved_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "materializers_dr_2.cache.behaviors[materializers_dr_2.cache.last_run_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Changing the cache format\n", + "\n", + "By default, results are stored in ``pickle`` format. It's a convenient default but [comes with caveats](https://grantjenks.com/docs/diskcache/tutorial.html#caveats). You can use the `@cache` decorator to specify another file format for storing results.\n", + "\n", + "By default this includes:\n", + "\n", + "- `json`\n", + "- `parquet`\n", + "- `csv`\n", + "- `excel`\n", + "- `file`\n", + "- `feather`\n", + "- `orc`\n", + "\n", + "This feature uses `DataLoader` and `DataSaver` under the hood and supports all of the same formats (including your custom ones, as long as they take a `path` attribute).\n", + "\n", + "> This is an area of active development. Feel free to share suggestions and feedback!\n", + "\n", + "The next cell sets `processed_data` to be cached using the `parquet` format." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "%%cell_to_module cache_format_module\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import cache\n", + "\n", + "DATA = {\n", + " \"cities\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Montréal\", \"Vancouver\", \"Houston\", \"Phoenix\", \"Mexico City\", \"Chihuahua City\", \"Rio de Janeiro\"],\n", + " \"date\": [\"2024-09-13\", \"2024-09-12\", \"2024-09-11\", \"2024-09-11\", \"2024-09-09\", \"2024-09-08\", \"2024-09-07\", \"2024-09-06\", \"2024-09-05\", \"2024-09-04\"],\n", + " \"amount\": [478.23, 251.67, 989.34, 742.14, 584.56, 321.85, 918.67, 135.22, 789.12, 432.78],\n", + " \"country\": [\"USA\", \"USA\", \"USA\", \"Canada\", \"Canada\", \"USA\", \"USA\", \"Mexico\", \"Mexico\", \"Brazil\"],\n", + " \"currency\": [\"USD\", \"USD\", \"USD\", \"CAD\", \"CAD\", \"USD\", \"USD\", \"MXN\", \"MXN\", \"BRL\"],\n", + "}\n", + "\n", + "def raw_data() -> pd.DataFrame:\n", + " \"\"\"Loading raw data. 
This simulates loading from a file, database, or external service.\"\"\"\n", + " return pd.DataFrame(DATA)\n", + "\n", + "@cache(format=\"parquet\")\n", + "def processed_data(raw_data: pd.DataFrame, cutoff_date: str) -> pd.DataFrame:\n", + " \"\"\"Filter out rows before cutoff date and convert currency to USD.\"\"\"\n", + " df = raw_data.loc[raw_data.date > cutoff_date].copy()\n", + " df[\"amound_in_usd\"] = df[\"amount\"]\n", + " df.loc[df.country == \"Canada\", \"amound_in_usd\"] *= 0.71 \n", + " df.loc[df.country == \"Brazil\", \"amound_in_usd\"] *= 0.18\n", + " df.loc[df.country == \"Mexico\", \"amound_in_usd\"] *= 0.05\n", + " return df\n", + "\n", + "def amount_per_country(processed_data: pd.DataFrame) -> pd.Series:\n", + " \"\"\"Sum the amount in USD per country\"\"\"\n", + " return processed_data.groupby(\"country\")[\"amound_in_usd\"].sum()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When executing the dataflow, we see `raw_data` recomputed because it's a dataloader. The result for `processed_data` will be retrieved, but it will be saved again as `.parquet` this time. " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "country\n", + "Canada 941.957\n", + "USA 1719.240\n", + "Name: amound_in_usd, dtype: float64\n", + "\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "Series\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache_format_dr = driver.Builder().with_modules(cache_format_module).with_cache().build()\n", + "\n", + "cache_format_results = cache_format_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "print()\n", + "print(cache_format_results[\"amount_per_country\"].head())\n", + "print()\n", + "cache_format_dr.cache.view_run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, under the `./.hamilton_cache`, there will be two results of the same name, one with the `.parquet` extension and one without. 
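Because the cached copy is a plain `.parquet` file, you can open it with any tool that reads parquet. As a small sketch, assuming the `parquet_path` computed in the next code cell:

```python
import pandas as pd

# read the cached intermediate result back directly, outside of Hamilton
cached_processed_data = pd.read_parquet(parquet_path)
print(cached_processed_data.head())
```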
The one without is actually a pickeld `DataLoader` to retrieve the `.parquet` file.\n", + "\n", + "You can access the path programmatically via the `result_store._path_from_data_version(...)` method." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_version = cache_format_dr.cache.data_versions[cache_format_dr.cache.last_run_id][\"processed_data\"]\n", + "parquet_path = cache_format_dr.cache.result_store._path_from_data_version(data_version).with_suffix(\".parquet\")\n", + "parquet_path.exists()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introspecting the cache\n", + "The `Driver.cache` stores information about all executions over its lifetime. Previous `run_id` are available through `Driver.cache.run_ids` and can be used in tandem without other utility functions:\n", + "\n", + "- Resolve the node caching behavior (e.g., \"recompute\")\n", + "- Access structured logs\n", + "- Visualize the cache execution\n", + "\n", + "Also, `Driver.cache.last_run_id` is a shortcut to the most recent execution." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amount_per_country': ,\n", + " 'processed_data': ,\n", + " 'raw_data': ,\n", + " 'cutoff_date': }" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cache_format_dr.cache.resolve_behaviors(cache_format_dr.cache.last_run_id)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "processed_data::adapter::resolve_behavior\n", + "processed_data::adapter::set_cache_key\n", + "processed_data::adapter::get_cache_key::hit\n", + "processed_data::adapter::get_data_version::miss\n", + "processed_data::metadata_store::get_data_version::miss\n", + "processed_data::adapter::execute_node\n", + "processed_data::adapter::set_data_version\n", + "processed_data::metadata_store::set_data_version\n", + "processed_data::adapter::get_cache_key::hit\n", + "processed_data::adapter::get_data_version::hit\n", + "processed_data::result_store::set_result\n", + "processed_data::adapter::get_data_version::hit\n", + "processed_data::adapter::resolve_behavior\n" + ] + } + ], + "source": [ + "run_logs = cache_format_dr.cache.logs(cache_format_dr.cache.last_run_id, level=\"debug\")\n", + "for event in run_logs[\"processed_data\"]:\n", + " print(event)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "processed_data\n", + "\n", + "processed_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "amount_per_country\n", + "\n", + "amount_per_country\n", + "Series\n", + "\n", + "\n", + "\n", + "processed_data->amount_per_country\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "raw_data\n", + "\n", + "raw_data\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "raw_data->processed_data\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_processed_data_inputs\n", + "\n", + "cutoff_date\n", + "str\n", + "\n", + "\n", + "\n", + "_processed_data_inputs->processed_data\n", + "\n", + "\n", + "\n", + 
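The same introspection attributes work across every run a `Driver` has executed. Here is a small sketch looping over `Driver.cache.run_ids` (this driver only has a single run so far; the pattern becomes more useful once you execute the dataflow several times, as in the next section):

```python
# summarize each recorded run; info-level logs only contain retrievals and executions
for run_id in cache_format_dr.cache.run_ids:
    print(run_id)
    for node_name, events in cache_format_dr.cache.logs(run_id, level="info").items():
        for event in events:
            print("  ", event)
```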
"\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + "\n", + "\n", + "\n", + "output\n", + "\n", + "output\n", + "\n", + "\n", + "\n", + "from cache\n", + "\n", + "from cache\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# for `.view_run()` passing no parameter is equivalent to the last `run_id`\n", + "cache_format_dr.cache.view_run(cache_format_dr.cache.last_run_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Interactively explore runs\n", + "By using `ipywidgets` we can easily build a widget to iterate over `run_id` values and display cache information. Below, we create a `Driver` and execute it a few times to generate data then inspect it with a widget." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "raw_data::result_store::get_result::hit\n", + "processed_data::result_store::get_result::hit\n", + "amount_per_country::result_store::get_result::hit\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::result_store::get_result::hit\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n", + "raw_data::result_store::get_result::hit\n", + "processed_data::adapter::execute_node\n", + "amount_per_country::adapter::execute_node\n" + ] + }, + { + "data": { + "text/plain": [ + "{'amount_per_country': Series([], Name: amound_in_usd, dtype: float64)}" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactive_dr = driver.Builder().with_modules(cache_format_module).with_cache().build()\n", + "\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-01\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-05\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-10\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-11\"})\n", + "interactive_dr.execute([\"amount_per_country\"], inputs={\"cutoff_date\": \"2024-09-13\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following cell allows you to click-and-drag or use arrow-keys to navigate" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a9785e33191453bac0b952ce1f80ef3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "interactive(children=(SelectionSlider(description='run_id', options=('101f1759-82c3-416b-875b-e184b765af3c', '…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display\n", + "from ipywidgets import SelectionSlider, interact\n", + "\n", + "\n", + "@interact(run_id=SelectionSlider(options=interactive_dr.cache.run_ids))\n", + "def iterate_over_runs(run_id):\n", + " display(interactive_dr.cache.data_versions[run_id])\n", + " 
display(interactive_dr.cache.view_run(run_id=run_id))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Managing storage\n", + "### Setting the cache `path`\n", + "\n", + "By default, metadata and results are stored under `./.hamilton_cache`, relative to the current directory at execution time. You can also manually set the directory via `.with_cache(path=...)` to isolate or centralize cache storage between dataflows or projects.\n", + "\n", + "Running the next cell will create the directory `./my_other_cache`." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "manual_path_dr = driver.Builder().with_modules(cache_format_module).with_cache(path=\"./my_other_cache\").build()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiating the `result_store` and `metadata_store`\n", + "If you need to store metadata and results in separate locations, you can do so by instantiating the `result_store` and `metadata_store` manually with their own configuration. In this case, setting `.with_cache(path=...)` would be ignored." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton.caching.stores.file import FileResultStore\n", + "from hamilton.caching.stores.sqlite import SQLiteMetadataStore\n", + "\n", + "result_store = FileResultStore(path=\"./results\")\n", + "metadata_store = SQLiteMetadataStore(path=\"./metadata\")\n", + "\n", + "manual_stores_dr = (\n", + " driver.Builder()\n", + " .with_modules(cache_format_module)\n", + " .with_cache(\n", + " result_store=result_store,\n", + " metadata_store=metadata_store,\n", + " )\n", + " .build()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Deleting data and recovering storage\n", + "As you use caching, you might be generating a lot of data that you don't need anymore. One straightforward solution is to delete the entire directory where metadata and results are stored. \n", + "\n", + "You can also programmatically call `.delete_all()` on the `result_store` and `metadata_store`, which should reclaim most storage. If you delete results, make sure to also delete metadata. The caching mechanism should figure it out, but it's safer to keep them in sync." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "manual_stores_dr.cache.metadata_store.delete_all()\n", + "manual_stores_dr.cache.result_store.delete_all()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage patterns\n", + "\n", + "As demonstrated here, caching works great in a notebook environment.\n", + "\n", + "- In addition to iteration speed, caching allows you to restart your kernel or shutdown your computer for the day without worry. When you'll come back, you will still be able to retrieve results from cache.\n", + "\n", + "- A similar benefit is the ability resume execution between environments. For example, you might be running Hamilton in a script, but when a bug happens you can reload these values in a notebook and investigate.\n", + "\n", + "- Caching works great with other adapters like the `HamiltonTracker` that powers the Hamilton UI and the `MLFlowTracker` for experiment tracking.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🚧 INTERNALS\n", + "If you're curious the following sections provide details about the caching internals. 
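One more storage-related option before digging into the internals: the cache adapter accepts a `log_to_file` flag that appends every caching event as a JSON line to a `cache_logs.jsonl` file in the cache directory. A sketch, assuming `.with_cache()` forwards this keyword to the adapter the same way it forwards `path` and the store instances:

```python
logging_dr = (
    driver.Builder()
    .with_modules(cache_format_module)
    .with_cache(path="./my_other_cache", log_to_file=True)  # assumption: kwarg is passed through to the adapter
    .build()
)

logging_dr.execute(["amount_per_country"], inputs={"cutoff_date": "2024-09-01"})
# events should now also be appended under ./my_other_cache/ as cache_logs.jsonl
```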
These APIs are not public and may change without notice." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Manually retrieve results\n", + "Using the `Driver.cache` you can directly retrieve results from previous executions. The cache stores \"data versions\" which are keys for the `result_store`. \n", + "\n", + "Here, we get the `run_id` for the 4th execution (index 3) and the data version for `processed_data` before retrieving its value." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.23\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.67\n" + ] + } + ], + "source": [ + "run_id = interactive_dr.cache.run_ids[3]\n", + "data_version = interactive_dr.cache.data_versions[run_id][\"processed_data\"]\n", + "result = interactive_dr.cache.result_store.get(data_version)\n", + "print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Decoding the `cache_key`\n", + "\n", + "By now, you should have a better grasp on how Hamilton's caching determines when to execute a node. Internally, it creates a `cache_key` from the `code_version` of the node and the `data_version` of each dependency. The cache keys are stored on the `Driver.cache` and can be decoded for introspection and debugging.\n", + "\n", + "Here, we get the `run_id` for the 3rd execution (index 2) and the cache key for `amount_per_country`. We then use `decode_key()` to retrieve the `node_name`, `code_version`, and `dependencies_data_versions`." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'node_name': 'amount_per_country',\n", + " 'code_version': 'c2ccafa54280fbc969870b6baa445211277d7e8cfa98a0821836c175603ffda2',\n", + " 'dependencies_data_versions': {'processed_data': 'WgV5-4SfdKTfUY66x-msj_xXsKNPNTP2guRhfw=='}}" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from hamilton.caching.cache_key import decode_key\n", + "\n", + "run_id = interactive_dr.cache.run_ids[2]\n", + "cache_key = interactive_dr.cache.cache_keys[run_id][\"amount_per_country\"]\n", + "decode_key(cache_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Indeed, this match the data version for `processed_data` for the 3rd execution." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'WgV5-4SfdKTfUY66x-msj_xXsKNPNTP2guRhfw=='" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "interactive_dr.cache.data_versions[run_id][\"processed_data\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Manually retrieve metadata\n", + "\n", + "In addition to the `result_store`, there is a `metadata_store` that contains mapping between `cache_key` and `data_version` (cache keys are unique, but many can point to the same data).\n", + "\n", + "Using the knowledge from the previous section, we can use the cache key for `amount_per_country` to retrieve its `data_version` and result. It's also possible to decode its `cache_key`, and get the `data_version` for its dependencies, making the node execution reproducible." 
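To make "reproducible" concrete, here is a sketch that combines the retrieval calls shown in the next cells with a direct call to the node's underlying function from `cache_format_module`: pull the dependency values from the `result_store`, re-execute the function, and compare against the cached result.

```python
from hamilton.caching.cache_key import decode_key

run_id = interactive_dr.cache.run_ids[2]
cache_key = interactive_dr.cache.cache_keys[run_id]["amount_per_country"]

# retrieve the cached value of each dependency
dependencies = {
    name: interactive_dr.cache.result_store.get(version)
    for name, version in decode_key(cache_key)["dependencies_data_versions"].items()
}

# re-execute the node function directly and compare with what the cache stored
recomputed = cache_format_module.amount_per_country(**dependencies)
cached = interactive_dr.cache.result_store.get(interactive_dr.cache.metadata_store.get(cache_key))
print(recomputed.equals(cached))  # expected: True
```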
+ ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "country\n", + "Canada 526.9194\n", + "USA 1719.2400\n", + "Name: amound_in_usd, dtype: float64\n" + ] + } + ], + "source": [ + "run_id = interactive_dr.cache.run_ids[2]\n", + "cache_key = interactive_dr.cache.cache_keys[run_id][\"amount_per_country\"]\n", + "amount_data_version = interactive_dr.cache.metadata_store.get(cache_key)\n", + "amount_result = interactive_dr.cache.result_store.get(amount_data_version)\n", + "print(amount_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "processed_data\n", + " cities date amount country currency amound_in_usd\n", + "0 New York 2024-09-13 478.23 USA USD 478.23\n", + "1 Los Angeles 2024-09-12 251.67 USA USD 251.67\n", + "2 Chicago 2024-09-11 989.34 USA USD 989.34\n", + "3 Montréal 2024-09-11 742.14 Canada CAD 526.9194\n", + "\n" + ] + } + ], + "source": [ + "for dep_name, dependency_data_version in decode_key(cache_key)[\"dependencies_data_versions\"].items():\n", + " dep_result = interactive_dr.cache.result_store.get(dependency_data_version)\n", + " print(dep_name)\n", + " print(dep_result)\n", + " print()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/caching_nodes/caching.ipynb b/examples/caching_nodes/caching.ipynb new file mode 100644 index 000000000..e7186eb47 --- /dev/null +++ b/examples/caching_nodes/caching.ipynb @@ -0,0 +1,375 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# First-class Caching in Hamilton\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "from hamilton import registry\n", + "registry.disable_autoload()\n", + "registry.load_extension(\"pandas\")\n", + "\n", + "%load_ext hamilton.plugins.jupyter_magic" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "cluster__legend\n", + "\n", + "Legend\n", + "\n", + "\n", + "\n", + "G\n", + "\n", + "G\n", + "DataFrame\n", + "\n", + "\n", + "\n", + "B\n", + "\n", + "B\n", + "float\n", + "\n", + "\n", + "\n", + "E\n", + "\n", + "E\n", + "str\n", + "\n", + "\n", + "\n", + "B->E\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "C\n", + "\n", + "C\n", + "bool\n", + "\n", + "\n", + "\n", + "B->C\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "F\n", + "\n", + "F\n", + "dict\n", + "\n", + "\n", + "\n", + "E->F\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "A\n", + "\n", + "A\n", + "int\n", + "\n", + "\n", + "\n", + "A->B\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "C->E\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_C_inputs\n", + "\n", + "D\n", + "bool\n", + "\n", + "\n", + "\n", + "_C_inputs->C\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "input\n", + "\n", + "input\n", + "\n", + "\n", + "\n", + "function\n", + "\n", + "function\n", + 
"\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%cell_to_module primitives -d\n", + "import pandas as pd\n", + "from hamilton.function_modifiers import tag\n", + "\n", + "def A() -> int:\n", + " return 7\n", + "\n", + "def B(A: int) -> float:\n", + " return float(A)\n", + "\n", + "def C(B: float, D: bool) -> bool:\n", + " return B != D\n", + "\n", + "@tag(cache=\"pickle\")\n", + "def E(C: bool, B: float) -> str:\n", + " return \"hello-world-ok\" * int(B)\n", + "\n", + "@tag(cache=\"json\")\n", + "def F(E: str) -> dict:\n", + " return {E: E*3}\n", + "\n", + "@tag(cache=\"parquet\")\n", + "def G() -> pd.DataFrame:\n", + " return pd.DataFrame({\"a\": [323, 3235], \"b\": [\"hello\", \"vorld\"]})" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from hamilton import driver\n", + "\n", + "dr = (\n", + " driver.Builder()\n", + " .with_modules(primitives)\n", + " .with_cache()\n", + " .build()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_keys(['A', 'B', 'C', 'D', 'E', 'F', 'G'])\n" + ] + } + ], + "source": [ + "results = dr.execute(\n", + " [\"A\", \"B\", \"C\", \"D\", \"E\", \"F\", \"G\"],\n", + " inputs=dict(D=True),\n", + " overrides=dict(B=4)\n", + ")\n", + "print(results.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A': 'eF4NxUEKwCAMBMCv9AlG48ZA6V+imx7t/28VBkaKJ5QqYa5LiXSzaMIMhuKULO2NAFnnmGDAZHQX9Hpc9/52Pj/UjxQx',\n", + " 'C': 'eF4FwcERgDAIBMBWLCGAEDLj2MvpwTP2/3O3SztyWAgrUb46ZoNSTq2Zr5xkBWD+eEKMiykEsEJHhh7X/nbdP+gBFJk=',\n", + " 'E': 'eF4FwckNgDAMBMBWKCHWOs5aQvQSfDxD/z9mAhT31VNFzDnENmoxM9L1taFMNfRcIk0HirsDbSQ6lPu6z3fq+QGUKRM7',\n", + " 'F': 'eF4FwcERwCAIBMBWUgLincBMJr2Iwafp/5ddiIc3WanjTUQ3zUYO8wEIW9AU6LL7jF1SKZ5qoPoszlJe9/lOPT9v4hJQ',\n", + " 'G': 'eF4FwckNgDAQA8BWKIEE7xEJ0YuJs8/Q/4+ZnvQQsoQl93yjsQuvZRhmBSVejiGiiXNwtThnucEJtzru/e31/NSBFG8='}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dr.cache.context_keys" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'metadata_store': ['initialized'],\n", + " 'A': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'in_memory_metadata::get::miss',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'result_store::get::hit',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'in_memory_metadata::set'],\n", + " 'B': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'as input',\n", + " 'data::versioned',\n", + " 'in_memory_metadata::set'],\n", + " 'C': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'in_memory_metadata::get::miss',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'result_store::get::hit',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'in_memory_metadata::set'],\n", + " 'E': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'in_memory_metadata::get::miss',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'result_store::get::hit',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'in_memory_metadata::set'],\n", + 
" 'F': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'in_memory_metadata::get::miss',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'result_store::get::hit',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'in_memory_metadata::set'],\n", + " 'G': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'in_memory_metadata::get::miss',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'result_store::get::hit',\n", + " 'context_key::created',\n", + " 'metadata_store::get::hit',\n", + " 'in_memory_metadata::set'],\n", + " 'D': ['caching_behavior:resolved',\n", + " 'code::versioned',\n", + " 'as input',\n", + " 'data::versioned',\n", + " 'in_memory_metadata::set']}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dr.cache.logs(dr.cache.run_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Caching actions\n", + "1. compute `code_version`\n", + "2. compute `data_version` for `value` using `hashing_function`\n", + "3. compute `value` by executing node with using `dependencies value`\n", + "4. create `context_key` using `code_version` and `dependencies data_version`\n", + "5. get `data_version` using `memory[node_name]`\n", + "6. set `data_version` using `memory[node_name]`\n", + "7. get `data_version` using `metadata_store[context_key]`\n", + "8. set `data_version` using `metadata_store[context_key]`\n", + "9. delete `data_version` using `metadata_store[context_key]`\n", + "10. get `value` from `result_store[data_version]`\n", + "11. set `value` from `result_store[data_version]`\n", + "12. store `value` using a materializer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Power user\n", + "- ignore dependencies from the `context_key` function\n", + "- skip `set`/`get` operations for `metadata_store` and `result_store` \n", + "- change the `hashing_function`\n", + "- change the `metadata_store`\n", + "- change the `result_store`\n", + "- change the `context_key` function" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/hamilton/caching/__init__.py b/hamilton/caching/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/hamilton/caching/adapter.py b/hamilton/caching/adapter.py new file mode 100644 index 000000000..c41bfa4fd --- /dev/null +++ b/hamilton/caching/adapter.py @@ -0,0 +1,1457 @@ +import collections +import dataclasses +import enum +import functools +import json +import logging +import pathlib +import uuid +from datetime import datetime, timezone +from typing import Any, Callable, Collection, Dict, List, Literal, Optional, TypeVar, Union + +import hamilton.node +from hamilton import graph_types +from hamilton.caching import fingerprinting +from hamilton.caching.cache_key import create_cache_key +from hamilton.caching.stores.base import ( + MetadataStore, + ResultRetrievalError, + ResultStore, + search_data_adapter_registry, +) +from hamilton.caching.stores.file import FileResultStore +from hamilton.caching.stores.sqlite import SQLiteMetadataStore +from 
hamilton.function_modifiers.metadata import cache as cache_decorator +from hamilton.graph import FunctionGraph +from hamilton.lifecycle.base import ( + BaseDoNodeExecute, + BasePostNodeExecute, + BasePreGraphExecute, + BasePreNodeExecute, +) + +logger = logging.getLogger("hamilton.caching") + +SENTINEL = object() +S = TypeVar("S", object, object) + + +CACHING_BEHAVIORS = Literal["default", "recompute", "disable", "ignore"] + + +class CachingBehavior(enum.Enum): + """Behavior applied by the caching adapter + + DEFAULT: + Try to retrieve result from cache instead of executing the node. If the node is executed, store the result. + Compute the result data version and store it too. + + RECOMPUTE: + Don't try to retrieve result from cache and always execute the node. Otherwise, behaves as default. + Useful when nodes are stochastic (e.g., model training) or interact with external + components (e.g., read from database). + + DISABLE: + Node is executed as if the caching feature wasn't enabled. + It never tries to retrieve results. Results are never stored nor versioned. + Behaves like IGNORE, but the node remains a dependency for downstream nodes. + This means downstream cache lookup will likely fail systematically (i.e., if the cache is empty). + + IGNORE: + Node is executed as if the caching feature wasn't enable. + It never tries to retrieve results. Results are never stored nor versioned. + IGNORE means downstream nodes will ignore this node as a dependency for lookup. + Ignoring clients and connections can be useful since they shouldn't directly impact the downstream results. + """ + + DEFAULT = 1 + RECOMPUTE = 2 + DISABLE = 3 + IGNORE = 4 + + @classmethod + def from_string(cls, string: str) -> "CachingBehavior": + """Create a caching behavior from a string of the enum value. This is + leveraged by the ``hamilton.lifecycle.caching.SmartCacheAdapter`` and + the ``hamilton.function_modifiers.metadata.cache`` decorator. + + .. code-block:: + + CachingBehavior.from_string("recompute") + + """ + try: + return cls[string.upper()] + except KeyError as e: + raise KeyError(f"{string} is an invalid `CachingBehavior` value") from e + + +class NodeRoleInTaskExecution(enum.Enum): + """Identify the role of a node in task-based execution, in particular when + ``Parallelizable/Collect`` are used. + + NOTE This is an internal construct and it will likely change in the future. + + STANDARD: when task-based execution is not used. All nodes and dependencies are STANDARD. + EXPAND: node with type ``Parallelizable``. It returns an iterator where individual items need to be handled. + Dependencies can only be OUTSIDE. + COLLECT: node with type ``Collect``. It returns an iterable where individual items need to be handled. + Dependencies can be INSIDE, OUTSIDE, or EXPAND + OUTSIDE: "outside" of ``Parallelizable/Collect`` paths; handled like STANDARD in most cases. + Dependencies can be OUTSIDE or COLLECT + INSIDE: "inside" or "between" a ``Parallelizable/Collect`` nodes. + Dependencies can be INSIDE, OUTSIDE, or EXPAND. 
+ """ + + STANDARD = 1 + EXPAND = 2 + COLLECT = 3 + OUTSIDE = 4 + INSIDE = 5 + + +class CachingEventType(enum.Enum): + """Event types logged by the caching adapter""" + + GET_DATA_VERSION = "get_data_version" + SET_DATA_VERSION = "set_data_version" + GET_CACHE_KEY = "get_cache_key" + SET_CACHE_KEY = "set_cache_key" + GET_RESULT = "get_result" + SET_RESULT = "set_result" + MISSING_RESULT = "missing_result" + FAILED_RETRIEVAL = "failed_retrieval" + EXECUTE_NODE = "execute_node" + FAILED_EXECUTION = "failed_execution" + RESOLVE_BEHAVIOR = "resolve_behavior" + UNHASHABLE_DATA_VERSION = "unhashable_data_version" + IS_OVERRIDE = "is_override" + IS_INPUT = "is_input" + IS_FINAL_VAR = "is_final_var" + IS_DEFAULT_PARAMETER_VALUE = "is_default_parameter_value" + + +@dataclasses.dataclass(frozen=True) +class CachingEvent: + """Event logged by the caching adapter""" + + run_id: str + actor: Literal["adapter", "metadata_store", "result_store"] + event_type: CachingEventType + node_name: str + task_id: Optional[str] = None + msg: Optional[str] = None + value: Optional[Any] = None + timestamp: float = dataclasses.field( + default_factory=lambda: datetime.now(timezone.utc).timestamp() + ) + + def __str__(self) -> str: + """Create a human-readable string format for `print()`""" + + string = self.node_name + if self.task_id is not None: + string += f"::{self.task_id}" + string += f"::{self.actor}" + string += f"::{self.event_type.value}" + if self.msg: # this catches None and empty strings + string += f"::{self.msg}" + + return string + + def as_dict(self): + return dict( + run_id=self.run_id, + timestamp=self.timestamp, + node_name=self.node_name, + task_id=self.task_id, + actor=self.actor, + event_type=self.event_type.value, + msg=self.msg, + value=str(self.value) if self.value else self.value, + ) + + +# TODO we could add a "driver-level" kwarg to specify the cache format (e.g., parquet, JSON, etc.) +class HamiltonCacheAdapter( + BaseDoNodeExecute, BasePreGraphExecute, BasePostNodeExecute, BasePreNodeExecute +): + """Adapter enabling Hamilton's caching feature through ``Builder.with_cache()`` + + .. code-block:: python + + from hamilton import driver + import my_dataflow + + dr = ( + driver.Builder() + .with_modules(my_dataflow) + .with_cache() + .build() + ) + + # then, you can access the adapter via + dr.cache + + """ + + def __init__( + self, + path: Union[str, pathlib.Path] = ".hamilton_cache", + metadata_store: Optional[MetadataStore] = None, + result_store: Optional[ResultStore] = None, + default: Optional[Union[Literal[True], Collection[str]]] = None, + recompute: Optional[Union[Literal[True], Collection[str]]] = None, + ignore: Optional[Union[Literal[True], Collection[str]]] = None, + disable: Optional[Union[Literal[True], Collection[str]]] = None, + default_behavior: Optional[CACHING_BEHAVIORS] = None, + default_loader_behavior: Optional[CACHING_BEHAVIORS] = None, + default_saver_behavior: Optional[CACHING_BEHAVIORS] = None, + log_to_file: bool = False, + **kwargs, + ): + """Initialize the cache adapter. + + :param path: path where the cache metadata and results will be stored + :param metadata_store: BaseStore handling metadata for the cache adapter + :param result_store: BaseStore caching dataflow execution results + :param default: Set caching behavior to DEFAULT for specified node names. If True, apply to all nodes. + :param recompute: Set caching behavior to RECOMPUTE for specified node names. If True, apply to all nodes. 
+ :param ignore: Set caching behavior to IGNORE for specified node names. If True, apply to all nodes. + :param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes. + :param default_behavior: Set the default caching behavior. + :param default_loader_behavior: Set the default caching behavior `DataLoader` nodes. + :param default_saver_behavior: Set the default caching behavior `DataSaver` nodes. + :param log_to_file: If True, append cache event logs as they happen in JSONL format. + """ + self._path = path + self.metadata_store = ( + metadata_store if metadata_store is not None else SQLiteMetadataStore(path=path) + ) + self.result_store = ( + result_store if result_store is not None else FileResultStore(path=str(path)) + ) + self.log_to_file = log_to_file + + if sum([default is True, recompute is True, disable is True, ignore is True]) > 1: + raise ValueError( + "Can only set one of (`default`, `recompute`, `disable`, `ignore`) to True. Please pass mutually exclusive sets of node names" + ) + self._default = default + self._recompute = recompute + self._disable = disable + self._ignore = ignore + self.default_behavior = default_behavior + self.default_loader_behavior = default_loader_behavior + self.default_saver_behavior = default_saver_behavior + + # attributes populated at execution time + self.run_ids: List[str] = [] + self._fn_graphs: Dict[str, FunctionGraph] = {} # {run_id: graph} + self._data_savers: Dict[str, Collection[str]] = {} # {run_id: list[node_name]} + self._data_loaders: Dict[str, Collection[str]] = {} # {run_id: list[node_name]} + self.behaviors: Dict[ + str, Dict[str, CachingBehavior] + ] = {} # {run_id: {node_name: behavior}} + self.data_versions: Dict[ + str, Dict[str, Union[str, Dict[str, str]]] + ] = {} # {run_id: {node_name: version}} or {run_id: {node_name: {task_id: version}}} + self.code_versions: Dict[str, Dict[str, str]] = {} # {run_id: {node_name: version}} + self.cache_keys: Dict[ + str, Dict[str, Union[str, Dict[str, str]]] + ] = {} # {run_id: {node_name: key}} or {run_id: {node_name: {task_id: key}}} + self._logs: Dict[str, List[CachingEvent]] = {} # {run_id: [logs]} + + @property + def last_run_id(self): + """Run id of the last started run. Not necessarily the last to complete.""" + return self.run_ids[-1] + + def __getstate__(self) -> dict: + """Serialization method required for multiprocessing and multithreading + when using task-based execution with `Parallelizable/Collect` + """ + state = self.__dict__.copy() + # store the classes to reinstantiate the same backend in __setstate__ + state["metadata_store_cls"] = self.metadata_store.__class__ + state["result_store_cls"] = self.result_store.__class__ + del state["metadata_store"] + del state["result_store"] + return state + + def __setstate__(self, state: dict) -> None: + """Serialization method required for multiprocessing and multithreading + when using task-based execution with `Parallelizable/Collect`. + + Create new instances of metadata and result stores to have one connection + per thread. + """ + # instantiate the backend from the class, then delete the attribute before + # setting it on the adapter instance. 
+ self.metadata_store = state["metadata_store_cls"](path=state["_path"]) + self.result_store = state["result_store_cls"](path=state["_path"]) + del state["metadata_store_cls"] + del state["result_store_cls"] + self.__dict__.update(state) + + def _log_event( + self, + run_id: str, + node_name: str, + actor: Literal["adapter", "metadata_store", "result_store"], + event_type: CachingEventType, + msg: Optional[str] = None, + value: Optional[Any] = None, + task_id: Optional[str] = None, + ) -> None: + """Add a single event to logs stored in state, keyed by run_id + + If global log level is set to logging.INFO, only log if event type is GET_RESULT or EXECUTE_NODE; + If it is set to logging.DEBUG, log all events. + + If `SmartCacheAdapter.log_to_file` is set to True, log all events to a file in JSONL format. + + :param node_name: name of the node associated with the event + :param task_id: optional identifier when using task-based execution. (node_name, task_id) is a primary key + :param actor: component responsible for the event + :param event_type: enum specifying what type of event (execute, retrieve, etc.) + :param msg: additional message to display in the logs (e.g., via terminal) + :param value: arbitrary value to include (typically a string for data version, code version, cache_key). Must be small and JSON-serializable. + """ + event = CachingEvent( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor=actor, + event_type=event_type, + msg=msg, + value=value, + ) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"{event.__str__()}") + elif logger.isEnabledFor(logging.INFO): + if event.event_type in (CachingEventType.GET_RESULT, CachingEventType.EXECUTE_NODE): + logger.info(f"{event.__str__()}") + + self._logs[run_id].append(event) + + if self.log_to_file: + log_file_path = pathlib.Path(self.metadata_store._directory, "cache_logs.jsonl") + json_line = json.dumps(event.as_dict()) + with log_file_path.open("a") as f: + f.write(json_line + "\n") + + def _log_by_node_name( + self, run_id: str, level: Literal["debug", "info"] = "info" + ) -> Dict[str, List[str]]: + """For a given run, group logs to key them by ``node_name`` or ``(node_name, run_id)`` if applicable.""" + run_logs = collections.defaultdict(list) + for event in self._logs[run_id]: + if level == "info": + if event.event_type not in ( + CachingEventType.GET_RESULT, + CachingEventType.EXECUTE_NODE, + ): + continue + + key = (event.node_name, event.task_id) if event.task_id else event.node_name + run_logs[key].append(event) + return dict(run_logs) + + def logs(self, run_id: Optional[str] = None, level: Literal["debug", "info"] = "info") -> dict: + """Execution logs of the cache adapter. + + :param run_id: If ``None``, return all logged runs. If provided a ``run_id``, group logs by node. + :param level: If ``"debug"`` log all events. If ``"info"`` only log if result is retrieved or executed. + :return: a mapping between node/task and a list of logged events + + .. code-block:: python + + from hamilton import driver + import my_dataflow + + dr = driver.Builder().with_modules(my_dataflow).with_cache().build() + dr.execute(...) + dr.execute(...) + + all_logs = dr.cache.logs() + # all_logs is a dictionary with run_ids as keys and lists of CachingEvent as values. 
+ # { + # run_id_1: [CachingEvent(...), CachingEvent(...)], + # run_id_2: [CachingEvent(...), CachingEvent(...)], + # } + + + run_logs = dr.cache.logs(run_id=dr.last_run_id) + # run_logs are keyed by ``node_name`` + # {node_name: [CachingEvent(...), CachingEvent(...)], ...} + # or ``(node_name, task_id)`` if task-based execution is used. + # {(node_name_1, task_id_1): [CachingEvent(...), CachingEvent(...)], ...} + + """ + if run_id: + return self._log_by_node_name(run_id=run_id, level=level) + + logs = collections.defaultdict(list) + for run_id, run_logs in self._logs.items(): + for event in run_logs: + if level == "info" and event.event_type not in ( + CachingEventType.GET_RESULT, + CachingEventType.EXECUTE_NODE, + ): + continue + + logs[run_id].append(event) + + return dict(logs) + + @staticmethod + def _view_run( + fn_graph: FunctionGraph, + logs, + final_vars: List[str], + inputs: dict, + overrides: dict, + output_file_path: Optional[str] = None, + ): + """Create a Hamilton visualization of the execution and the cache hits/misses. + + This leverages the ``custom_style_function`` feature internally. + """ + from hamilton.driver import Driver # avoid circular import + + def _visualization_styling_function(*, node, node_class, logs): + """Custom style function for the visualization.""" + if any( + event.event_type == CachingEventType.GET_RESULT for event in logs.get(node.name, []) + ): + style = ( + {"penwidth": "3", "color": "#F06449", "fillcolor": "#ffffff"}, + node_class, + "from cache", + ) + else: + style = ({}, node_class, None) + + return style + + return Driver._visualize_execution_helper( + adapter=None, + bypass_validation=True, + render_kwargs={}, + output_file_path=output_file_path, + fn_graph=fn_graph, + final_vars=final_vars, + inputs=inputs, + overrides=overrides, + custom_style_function=functools.partial(_visualization_styling_function, logs=logs), + ) + + # TODO make this work directly from the metadata_store too + # visualization from logs is convenient when debugging someone else's issue + def view_run(self, run_id: Optional[str] = None, output_file_path: Optional[str] = None): + """View the dataflow execution, including cache hits/misses. + + :param run_id: If ``None``, view the last run. If provided a ``run_id``, view that run. + :param output_file_path: If provided a path, save the visualization to a file. + + .. code-block:: python + + from hamilton import driver + import my_dataflow + + dr = driver.Builder().with_modules(my_dataflow).with_cache().build() + + # execute 3 times + dr.execute(...) + dr.execute(...) + dr.execute(...) + + # view the last run + dr.cache.view_run() + # this is equivalent to + dr.cache.view_run(run_id=dr.last_run_id) + + # get a specific run id + run_id = dr.cache.run_ids[1] + dr.cache.view_run(run_id=run_id) + + """ + if run_id is None: + run_id = self.last_run_id + + fn_graph = self._fn_graphs[run_id] + logs = self.logs(run_id, level="debug") + + final_vars = [] + inputs = {} + overrides = {} + for key, events in logs.items(): + if isinstance(key, tuple): + raise ValueError( + "`.view()` is currently not supported for task-based execution. " + "Please inspect the logs directly via `.logs(run_id=...)` for debugging." 
+ ) + + node_name = key + if any(e.event_type == CachingEventType.IS_FINAL_VAR for e in events): + final_vars.append(node_name) + + if any(e.event_type == CachingEventType.IS_INPUT for e in events): + inputs[node_name] = None # the value doesn't matter, only the key of the dict + continue + + elif any(e.event_type == CachingEventType.IS_OVERRIDE for e in events): + overrides[node_name] = None # the value doesn't matter, only the key of the dict + continue + + return self._view_run( + fn_graph=fn_graph, + logs=logs, + final_vars=final_vars, + inputs=inputs, + overrides=overrides, + output_file_path=output_file_path, + ) + + def _get_node_role( + self, run_id: str, node_name: str, task_id: Optional[str] + ) -> NodeRoleInTaskExecution: + """Determine based on the node name and task_id if a node is part of parallel execution.""" + if task_id is None: + role = NodeRoleInTaskExecution.STANDARD + else: + node_type: hamilton.node.NodeType = self._fn_graphs[run_id].nodes[node_name].node_role + if node_type == hamilton.node.NodeType.EXPAND: + role = NodeRoleInTaskExecution.EXPAND + elif node_type == hamilton.node.NodeType.COLLECT: + role = NodeRoleInTaskExecution.COLLECT + elif node_name == task_id: + role = NodeRoleInTaskExecution.OUTSIDE + else: + role = NodeRoleInTaskExecution.INSIDE + + return role + + def get_cache_key( + self, run_id: str, node_name: str, task_id: Optional[str] = None + ) -> Union[str, S]: + """Get the ``cache_key`` stored in-memory for a specific ``run_id``, ``node_name``, and ``task_id``. + + This method is public-facing and can be used directly to inspect the cache. + + :param run_id: Id of the Hamilton execution run. + :param node_name: Name of the node associated with the cache key. ``node_name`` is a unique identifier + if task-based execution is not used. + :param task_id: Id of the task when task-based execution is used. Then, the tuple ``(node_name, task_id)`` + is a unique identifier. + :return: The cache key if it exists, otherwise return a sentinel value. + + .. code-block:: python + + from hamilton import driver + import my_dataflow + + dr = driver.Builder().with_modules(my_dataflow).with_cache().build() + dr.execute(...) + + dr.cache.get_cache_key(run_id=dr.last_run_id, node_name="my_node", task_id=None) + + """ + node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id) + + if node_role == NodeRoleInTaskExecution.INSIDE: + cache_key = self.cache_keys[run_id].get(node_name, {}).get(task_id, SENTINEL) # type: ignore ; `task_id` can't be None + else: + cache_key = self.cache_keys[run_id].get(node_name, SENTINEL) + + cache_key = cache_key if cache_key is not SENTINEL else None + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.GET_CACHE_KEY, + msg="hit" if cache_key is not SENTINEL else "miss", + value=cache_key, + ) + return cache_key + + def _set_cache_key( + self, run_id: str, node_name: str, cache_key: str, task_id: Optional[str] = None + ) -> None: + """Set the ``cache_key`` stored in-memory for a specific ``run_id``, ``node_name``, and ``task_id``. + + When calling this method, ``cache_key`` must not be ``None``. 
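For illustration, a key retrieved via ``get_cache_key()`` can be decoded with the helpers added in ``hamilton/caching/cache_key.py`` later in this patch. This continues the ``dr`` from the example above; ``my_node``, ``upstream_node``, and the shown versions are placeholders.

.. code-block:: python

    from hamilton.caching.cache_key import decode_key

    cache_key = dr.cache.get_cache_key(run_id=dr.cache.last_run_id, node_name="my_node")
    decode_key(cache_key)
    # {
    #     "node_name": "my_node",
    #     "code_version": "<hash of the function source>",
    #     "dependencies_data_versions": {"upstream_node": "<data version>"},
    # }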
+ """ + assert cache_key is not None + node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id) + if node_role in ( + NodeRoleInTaskExecution.STANDARD, + NodeRoleInTaskExecution.OUTSIDE, + NodeRoleInTaskExecution.EXPAND, + NodeRoleInTaskExecution.COLLECT, + ): + self.cache_keys[run_id][node_name] = cache_key + elif node_role == NodeRoleInTaskExecution.INSIDE: + if self.cache_keys[run_id].get(node_name, SENTINEL) is SENTINEL: + self.cache_keys[run_id][node_name] = {} + self.cache_keys[run_id][node_name][task_id] = cache_key # type: ignore ; we just initialized the nested dict + else: + raise ValueError( + f"Received `{node_role}`. Unhandled `NodeRoleInTaskExecution`, please report this bug." + ) + + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.SET_CACHE_KEY, + value=cache_key, + ) + + def _get_memory_data_version( + self, run_id: str, node_name: str, task_id: Optional[str] = None + ) -> Union[str, S]: + """Get the ``data_version`` stored in-memory for a specific ``run_id``, ``node_name``, and ``task_id``. + + The behavior depends on the ``CacheBehavior`` (e.g., RECOMPUTE, IGNORE, DISABLE, DEFAULT) and + the ``NodeRoleInTaskExecution`` of the node (e.g., STANDARD, OUTSIDE, INSIDE, EXPAND, COLLECT). + + :param run_id: Id of the Hamilton execution run. + :param node_name: Name of the node associated with the cache key. ``node_name`` is a unique identifier + if task-based execution is not used. + :param task_id: Id of the task when task-based execution is used. Then, the tuple ``(node_name, task_id)`` + is a unique identifier. + """ + node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id) + if node_role in ( + NodeRoleInTaskExecution.STANDARD, + NodeRoleInTaskExecution.OUTSIDE, + NodeRoleInTaskExecution.COLLECT, + ): + data_version = self.data_versions[run_id].get(node_name, SENTINEL) + elif node_role == NodeRoleInTaskExecution.EXPAND: + data_version = SENTINEL + elif node_role == NodeRoleInTaskExecution.INSIDE: + tasks_data_versions = self.data_versions[run_id].get(node_name, SENTINEL) + if isinstance(tasks_data_versions, dict): + data_version = tasks_data_versions.get(task_id, SENTINEL) + else: + data_version = SENTINEL + else: + raise ValueError( + f"Received `{node_role}`. Unhandled `NodeRoleInTaskExecution`, please report this bug." + ) + + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.GET_DATA_VERSION, + msg="hit" if data_version is not SENTINEL else "miss", + ) + return data_version + + def _get_stored_data_version( + self, run_id: str, node_name: str, cache_key: str, task_id: Optional[str] = None + ) -> Union[str, S]: + """Get the ``data_version`` stored in the metadata store associated with the ``cache_key``. + + The ``run_id``, ``node_name``, and ``task_id`` are included only for logging purposes. 
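To make the hit/miss logging above concrete, here is a sketch of how a retrieval shows up in the info-level logs on a repeated execution. ``my_dataflow`` and ``my_node`` are placeholders, and this assumes the node's code and inputs are unchanged between the two runs.

.. code-block:: python

    from hamilton import driver
    import my_dataflow  # hypothetical user module

    dr = driver.Builder().with_modules(my_dataflow).with_cache().build()

    dr.execute(["my_node"])  # first run: nodes are executed and results stored
    dr.execute(["my_node"])  # second run: results are retrieved from the cache

    for event in dr.cache.logs(run_id=dr.cache.last_run_id)["my_node"]:
        print(event)
    # my_node::result_store::get_result::hit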
+ """ + data_version = self.metadata_store.get(cache_key=cache_key) + data_version = SENTINEL if data_version is None else data_version + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="metadata_store", + event_type=CachingEventType.GET_DATA_VERSION, + msg="hit" if data_version is not SENTINEL else "miss", + ) + + return data_version + + def get_data_version( + self, + run_id: str, + node_name: str, + cache_key: Optional[str] = None, + task_id: Optional[str] = None, + ) -> Union[str, S]: + """Get the ``data_version`` for a specific ``run_id``, ``node_name``, and ``task_id``. + + This method is public-facing and can be used directly to inspect the cache. This will check data versions + stored both in-memory and in the metadata store. + + :param run_id: Id of the Hamilton execution run. + :param node_name: Name of the node associated with the data version. ``node_name`` is a unique identifier + if task-based execution is not used. + :param task_id: Id of the task when task-based execution is used. Then, the tuple ``(node_name, task_id)`` + is a unique identifier. + :return: The data version if it exists, otherwise return a sentinel value. + + ..code-block:: python + + from hamilton import driver + import my_dataflow + + dr = driver.Builder().with_modules(my_dataflow).with_cache().build() + dr.execute(...) + + dr.cache.get_data_version(run_id=dr.last_run_id, node_name="my_node", task_id=None) + + """ + + data_version = self._get_memory_data_version( + run_id=run_id, node_name=node_name, task_id=task_id + ) + + if data_version is SENTINEL and cache_key is not None: + data_version = self._get_stored_data_version( + run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key + ) + + return data_version + + def _set_memory_metadata( + self, run_id: str, node_name: str, data_version: str, task_id: Optional[str] = None + ) -> None: + """Set in-memory data_version whether a task_id is specified or not""" + assert data_version is not None + node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id) + if node_role in ( + NodeRoleInTaskExecution.STANDARD, + NodeRoleInTaskExecution.OUTSIDE, + NodeRoleInTaskExecution.COLLECT, + ): + self.data_versions[run_id][node_name] = data_version + elif node_role == NodeRoleInTaskExecution.EXPAND: + self.data_versions[run_id][node_name] = {} + elif node_role == NodeRoleInTaskExecution.INSIDE: + if self.data_versions[run_id].get(node_name, SENTINEL) is SENTINEL: + self.data_versions[run_id][node_name] = {} + self.data_versions[run_id][node_name][task_id] = data_version # type: ignore ; we just initialized the nested dict + else: + raise ValueError( + f"Received `{node_role}`. Unhandled `NodeRoleInTaskExecution`, please report this bug." 
+ ) + + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.SET_DATA_VERSION, + value=data_version, + ) + + def _set_stored_metadata( + self, + run_id: str, + node_name: str, + cache_key: str, + data_version: str, + task_id: Optional[str] = None, + ) -> None: + """Set data_version in the metadata store associated with the cache_key""" + self.metadata_store.set( + run_id=run_id, + node_name=node_name, + code_version=self.code_versions[run_id][node_name], + data_version=data_version, + cache_key=cache_key, + ) + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="metadata_store", + event_type=CachingEventType.SET_DATA_VERSION, + value=data_version, + ) + + def _version_data( + self, node_name: str, run_id: str, result: Any, task_id: Optional[str] = None + ) -> str: + """Create a unique data version for the result""" + data_version = fingerprinting.hash_value(result) + if data_version == fingerprinting.UNHASHABLE: + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.UNHASHABLE_DATA_VERSION, + msg=f"unhashable type {type(result)}; set CachingBehavior.IGNORE to silence warning", + value=data_version, + ) + logger.warning( + f"Node `{node_name}` has unhashable result of type `{type(result)}`. " + "Set `CachingBehavior.IGNORE` or register a versioning function to silence warning. " + "Learn more: https://hamilton.dagworks.io/en/latest/concepts/caching/#caching-behavior\n" + ) + # if the data version is unhashable, we need to set a random suffix to the cache_key + # to prevent the cache from thinking this value is constant, causing a cache hit. + data_version = "" + f"_{uuid.uuid4()}" + + return data_version + + def version_data(self, result: Any, run_id: str = None) -> str: + """Create a unique data version for the result + + This is a user-facing method. + """ + # stuff the internal function call to not log event + return self._version_data(result=result, run_id=run_id, node_name=None) + + def version_code(self, node_name: str, run_id: Optional[str] = None) -> str: + """Create a unique code version for the source code defining the node""" + run_id = self.last_run_id if run_id is None else run_id + node = self._fn_graphs[run_id].nodes[node_name] + return graph_types.HamiltonNode.from_node(node).version # type: ignore + + def _execute_node( + self, + run_id: str, + node_name: str, + node_callable: Callable, + node_kwargs: Dict[str, Any], + task_id: Optional[str] = None, + ) -> Any: + """Simple wrapper that logs the regular execution of a node.""" + logger.debug(node_name) + result = node_callable(**node_kwargs) + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.EXECUTE_NODE, + ) + return result + + @staticmethod + def _resolve_node_behavior( + node: hamilton.node.Node, + default: Optional[Collection[str]] = None, + disable: Optional[Collection[str]] = None, + recompute: Optional[Collection[str]] = None, + ignore: Optional[Collection[str]] = None, + default_behavior: CACHING_BEHAVIORS = "default", + default_loader_behavior: CACHING_BEHAVIORS = "default", + default_saver_behavior: CACHING_BEHAVIORS = "default", + ) -> CachingBehavior: + """Determine the cache behavior of a node. + Behavior specified via the ``Builder`` has precedence over the ``@cache`` decorator. + Otherwise, set the ``DEFAULT`` behavior. 
+ If the node is `Parallelizable` enforce the ``RECOMPUTE`` behavior to ensure + yielded items are versioned individually. + """ + if node.node_role == hamilton.node.NodeType.EXPAND: + return CachingBehavior.RECOMPUTE + + behavior_from_tag = node.tags.get(cache_decorator.BEHAVIOR_KEY, SENTINEL) + if behavior_from_tag is not SENTINEL: + behavior_from_tag = CachingBehavior.from_string(behavior_from_tag) + + behavior_from_driver = SENTINEL + for behavior, node_set in ( + (CachingBehavior.DEFAULT, default), + (CachingBehavior.DISABLE, disable), + (CachingBehavior.RECOMPUTE, recompute), + (CachingBehavior.IGNORE, ignore), + ): + # guard against default None value + if node_set is None: + continue + + if node.name in node_set: + if behavior_from_driver is not SENTINEL: + raise ValueError( + f"Multiple caching behaviors specified by Driver for node: {node.name}" + ) + behavior_from_driver = behavior + + if behavior_from_driver is not SENTINEL: + return behavior_from_driver + elif behavior_from_tag is not SENTINEL: + return behavior_from_tag + elif node.tags.get("hamilton.data_loader"): + return CachingBehavior.from_string(default_loader_behavior) + elif node.tags.get("hamilton.data_saver"): + return CachingBehavior.from_string(default_saver_behavior) + else: + return CachingBehavior.from_string(default_behavior) + + def resolve_behaviors(self, run_id: str) -> Dict[str, CachingBehavior]: + """Resolve the caching behavior for each node based on the ``@cache`` decorator + and the ``Builder.with_cache()`` parameters for a specific ``run_id``. + + This is a user-facing method. + + Behavior specified via ``Builder.with_cache()`` have precedence. If no parameters are specified, + the ``CachingBehavior.DEFAULT`` is used. If a node is ``Parallelizable`` (i.e., ``@expand``), + the ``CachingBehavior`` is set to ``CachingBehavior.RECOMPUTE`` to ensure the yielded items + are versioned individually. Internally, this uses the ``FunctionGraph`` stored for each ``run_id`` and logs + the resolved caching behavior for each node. + + :param run_id: Id of the Hamilton execution run. + :return: A dictionary of ``{node name: caching behavior}``. 
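For illustration of this precedence, assuming the ``@cache`` decorator exposes a ``behavior=`` keyword that sets the tag read above (placeholder module and node names):

.. code-block:: python

    # my_dataflow.py (hypothetical module)
    from hamilton.function_modifiers import cache

    @cache(behavior="recompute")  # module-level request via the decorator
    def raw_data() -> dict:
        return {"a": 1}

    # driver script: Builder-level settings take precedence, so ``raw_data``
    # resolves to DISABLE even though the decorator asked for RECOMPUTE.
    from hamilton import driver
    import my_dataflow

    dr = driver.Builder().with_modules(my_dataflow).with_cache(disable=["raw_data"]).build()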
+ """ + graph = self._fn_graphs[run_id] + + _default = self._default + _disable = self._disable + _recompute = self._recompute + _ignore = self._ignore + + if _default is True: + _default = [n.name for n in graph.get_nodes()] + elif _disable is True: + _disable = [n.name for n in graph.get_nodes()] + elif _recompute is True: + _recompute = [n.name for n in graph.get_nodes()] + elif _ignore is True: + _ignore = [n.name for n in graph.get_nodes()] + + default_behavior = "default" + if self.default_behavior is not None: + default_behavior = self.default_behavior + + default_loader_behavior = default_behavior + if self.default_loader_behavior is not None: + default_loader_behavior = self.default_loader_behavior + + default_saver_behavior = default_behavior + if self.default_saver_behavior is not None: + default_saver_behavior = self.default_saver_behavior + + behaviors = {} + for node in graph.get_nodes(): + behavior = HamiltonCacheAdapter._resolve_node_behavior( + node=node, + default=_default, + disable=_disable, + recompute=_recompute, + ignore=_ignore, + default_behavior=default_behavior, + default_loader_behavior=default_loader_behavior, + default_saver_behavior=default_saver_behavior, + ) + behaviors[node.name] = behavior + + self._log_event( + run_id=run_id, + node_name=node.name, + task_id=None, + actor="adapter", + event_type=CachingEventType.RESOLVE_BEHAVIOR, + value=behavior, + ) + + # need to handle materializers via a second pass to copy the behavior + # of their "main node" + for node in graph.get_nodes(): + if node.tags.get("hamilton.data_loader") is True: + main_node = node.tags["hamilton.data_loader.node"] + if main_node == node.name: + continue + + # solution for `@dataloader` and `from_` + if behaviors.get(main_node, None) is not None: + behaviors[node.name] = behaviors[main_node] + # this hacky section is required to support @load_from and provide + # a unified pattern to specify behavior from the module or the driver + else: + behaviors[node.name] = HamiltonCacheAdapter._resolve_node_behavior( + # we create a fake node, only its name matters + node=hamilton.node.Node( + name=main_node, + typ=str, + callabl=lambda: None, + tags=node.tags.copy(), + ), + default=_default, + disable=_disable, + recompute=_recompute, + ignore=_ignore, + default_behavior=default_loader_behavior, + ) + + self._data_loaders[run_id].append(main_node) + + if node.tags.get("hamilton.data_saver", None) is not None: + self._data_savers[run_id].append(node.name) + + return behaviors + + def resolve_code_versions( + self, + run_id: str, + final_vars: Optional[List[str]] = None, + inputs: Optional[Dict[str, Any]] = None, + overrides: Optional[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Resolve the code version for each node for a specific ``run_id``. + + This is a user-facing method. + + If ``final_vars`` is None, all nodes will be versioned. If ``final_vars`` is provided, + the ``inputs`` and ``overrides`` are used to determine the execution path and only + version the code for these nodes. + + :param run_id: Id of the Hamilton execution run. + :param final_vars: Nodes requested for execution. + :param inputs: Input node values. + :param overrides: Override node values. + :return: A dictionary of ``{node name: code version}``. 
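A sketch of inspecting code versions after a run, which is what drives recomputation when a function's source changes. ``my_dataflow`` and ``my_node`` are placeholders, and this assumes ``my_node`` requires no external inputs.

.. code-block:: python

    from hamilton import driver
    import my_dataflow  # hypothetical user module

    dr = driver.Builder().with_modules(my_dataflow).with_cache().build()
    dr.execute(["my_node"])

    dr.cache.resolve_code_versions(run_id=dr.cache.last_run_id, final_vars=["my_node"])
    # {"my_node": "<hash of the node's source>", "upstream_node": "<hash>", ...}

    # equivalently, for a single node:
    dr.cache.version_code("my_node")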
+ """ + graph = self._fn_graphs[run_id] + + final_vars = [] if final_vars is None else final_vars + inputs = {} if inputs is None else inputs + overrides = {} if overrides is None else overrides + + node_selection = graph.get_nodes() + if len(final_vars) > 0: + all_nodes, user_defined_nodes = graph.get_upstream_nodes(final_vars, inputs, overrides) + node_selection = set(all_nodes) - set(user_defined_nodes) + + return { + node.name: self.version_code(run_id=run_id, node_name=node.name) + for node in node_selection + } + + def _process_input(self, run_id: str, node_name: str, value: Any) -> None: + """Process input nodes to version data and code. + + To enable caching, input values must be versioned. Since inputs have no associated code, + set a constant "code version" ``f"input__{node_name}"`` that uniquely identifies this input. + """ + data_version = self._version_data(node_name=node_name, run_id=run_id, result=value) + self.code_versions[run_id][node_name] = f"input__{node_name}" + self.data_versions[run_id][node_name] = data_version + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=None, + actor="adapter", + event_type=CachingEventType.IS_INPUT, + value=data_version, + ) + + def _process_override(self, run_id: str, node_name: str, value: Any) -> None: + """Process override nodes to version data and code. + + To enable caching, override values must be versioned. As opposed to executed nodes, + code and data versions for overrides are not stored because their value is user provided + and isn't necessarily tied to the code. + """ + data_version = self._version_data(node_name=node_name, run_id=run_id, result=value) + self.data_versions[run_id][node_name] = data_version + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=None, + actor="adapter", + event_type=CachingEventType.IS_OVERRIDE, + value=data_version, + ) + + @staticmethod + def _resolve_default_parameter_values( + node_: hamilton.node.Node, node_kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + """ + If a node uses the function's default parameter values, they won't be part of the + node_kwargs. To ensure a consistent `cache_key` we want to retrieve default parameter + values if they're used + """ + resolved_kwargs = node_kwargs.copy() + for param_name, param_value in node_.default_parameter_values.items(): + # if the `param_name` not in `node_kwargs`, it means the node uses the default + # parameter value + if param_name not in node_kwargs.keys(): + resolved_kwargs.update(**{param_name: param_value}) + + return resolved_kwargs + + def pre_graph_execute( + self, + *, + run_id: str, + graph: FunctionGraph, + final_vars: List[str], + inputs: Dict[str, Any], + overrides: Dict[str, Any], + ): + """Set up the state of the adapter for a new execution. + + Most attributes need to be keyed by run_id to prevent potential conflicts because + the same adapter instance is shared between across all ``Driver.execute()`` calls. 
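Since the same adapter instance is shared across ``Driver.execute()`` calls, all of this state is keyed by ``run_id``. A sketch of inspecting it after a couple of executions (placeholder module and node names):

.. code-block:: python

    from hamilton import driver
    import my_dataflow  # hypothetical user module

    dr = driver.Builder().with_modules(my_dataflow).with_cache().build()

    dr.execute(["my_node"])
    dr.execute(["my_node"])

    dr.cache.run_ids                              # one run_id per execute() call
    dr.cache.behaviors[dr.cache.last_run_id]      # {node_name: CachingBehavior}
    dr.cache.data_versions[dr.cache.last_run_id]  # {node_name: data_version}
    dr.cache.cache_keys[dr.cache.last_run_id]     # {node_name: cache_key}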
+ """ + self.run_ids.append(run_id) + self.metadata_store.initialize(run_id) + self._logs[run_id] = [] + + self._fn_graphs[run_id] = graph + self.data_versions[run_id] = {} + self.cache_keys[run_id] = {} + self.code_versions[run_id] = self.resolve_code_versions( + run_id=run_id, final_vars=final_vars, inputs=inputs, overrides=overrides + ) + # the empty `._data_loaders` and `._data_savers` need to be instantiated before calling + # `self.resolve_behaviors` because it appends to them + self._data_loaders[run_id] = [] + self._data_savers[run_id] = [] + self.behaviors[run_id] = self.resolve_behaviors(run_id=run_id) + + # final vars are logged to be retrieved by the ``.view_run()`` method + for final_var in final_vars: + self._log_event( + run_id=run_id, + node_name=final_var, + task_id=None, + actor="adapter", + event_type=CachingEventType.IS_FINAL_VAR, + ) + + if inputs: + for node_name, value in inputs.items(): + self._process_input(run_id, node_name, value) + + if overrides: + for node_name, value in overrides.items(): + self._process_override(run_id, node_name, value) + + def pre_node_execute( + self, + *, + run_id: str, + node_: hamilton.node.Node, + kwargs: Dict[str, Any], + task_id: Optional[str] = None, + **future_kwargs, + ): + """Before node execution or retrieval, create the cache_key and set it in memory. + The cache_key is created based on the node's code version and its dependencies' data versions. + + Collecting ``data_version`` for upstream dependencies requires handling special cases when + task-based execution is used: + - If the current node is ``COLLECT`` , the dependency annotated with ``Collect[]`` needs to + be versioned item by item instead of versioning the full container. This is because the + collect order is inconsistent. + - If the current node is ``INSIDE`` and the dependency is ``EXPAND``, this means the + ``kwargs`` dictionary contains a single item. We need to version this individual item because + it will not be available from "inside" the branch for some executors (multiprocessing, multithreading) + because they lose access to the data_versions of ``OUTSIDE`` nodes stored in ``self.data_versions``. + + """ + node_name = node_.name + node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs) + + if self.behaviors[run_id][node_name] == CachingBehavior.IGNORE: + return + + # won't need the cache_key for either result retrieval or storage + if self.behaviors[run_id][node_name] == CachingBehavior.DISABLE: + return + + node_role = self._get_node_role(run_id=run_id, node_name=node_name, task_id=task_id) + collected_name = ( + node_.collect_dependency if node_role == NodeRoleInTaskExecution.COLLECT else SENTINEL + ) + + dependencies_data_versions = {} + for dep_name, dep_value in node_kwargs.items(): + # resolve caching behaviors + if self.behaviors[run_id][dep_name] == CachingBehavior.IGNORE: + # setting the data_version to "" in the cache_key means that + # the value of the dependency appears constant to this node + dependencies_data_versions[dep_name] = "" + continue + elif self.behaviors[run_id][dep_name] == CachingBehavior.DISABLE: + # setting the data_version to "" with a random suffix in the + # cache_key means the current node will be a cache miss and forced to recompute + dependencies_data_versions[dep_name] = "" + f"_{uuid.uuid4()}" + continue + + # resolve NodeRoleInTaskExecution + if task_id is None: + dep_role = NodeRoleInTaskExecution.STANDARD + else: + # want to check if dependency is an EXPAND node. 
We must not pass the current `task_id` + dep_role = self._get_node_role( + run_id=run_id, node_name=dep_name, task_id="" + ) + + # if dep_role == NodeRoleInTaskExecution.STANDARD: + + if dep_name == collected_name: + # the collected value should be hashed based on the items, not the container + items_data_versions = [self.version_data(item, run_id=run_id) for item in dep_value] + dep_data_version = fingerprinting.hash_value(sorted(items_data_versions)) + + elif dep_role == NodeRoleInTaskExecution.EXPAND: + # if the dependency is `EXPAND`, the kwarg received is a single item yielded by the iterator + # rather than the full iterable. We must version it directly, similar to a top-level input + dep_data_version = self.version_data(dep_value, run_id=run_id) + + else: + tasks_data_versions = self._get_memory_data_version( + run_id=run_id, node_name=dep_name, task_id=None + ) + if tasks_data_versions is SENTINEL: + dep_data_version = self.version_data(dep_value, run_id=run_id) + elif isinstance(tasks_data_versions, dict): + dep_data_version = tasks_data_versions.get(task_id) + else: + dep_data_version = tasks_data_versions + + if dep_data_version == fingerprinting.UNHASHABLE: + # if the data version is unhashable, we need to set a random suffix to the cache_key + # to prevent the cache from thinking this value is constant, causing a cache hit. + dep_data_version = "" + f"_{uuid.uuid4()}" + + dependencies_data_versions[dep_name] = dep_data_version + + # create cache_key before execution; will be reused during and after execution + cache_key = create_cache_key( + node_name=node_name, + code_version=self.code_versions[run_id][node_name], + dependencies_data_versions=dependencies_data_versions, + ) + self._set_cache_key( + run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key + ) + + def do_node_execute( + self, + *, + run_id: str, + node_: hamilton.node.Node, + kwargs: Dict[str, Any], + task_id: Optional[str] = None, + **future_kwargs, + ): + """Try to retrieve stored result from previous executions or execute the node. + + Use the previously created cache_key to retrieve the data_version from memory or the metadata_store. + If data_version is retrieved try to retrieve the result. If it fails, execute the node. + Else, execute the node. + """ + node_name = node_.name + node_callable = node_.callable + node_kwargs = HamiltonCacheAdapter._resolve_default_parameter_values(node_, kwargs) + + if self.behaviors[run_id][node_name] in ( + CachingBehavior.DISABLE, + CachingBehavior.IGNORE, + CachingBehavior.RECOMPUTE, + ): + result = self._execute_node( + run_id=run_id, + node_name=node_name, + node_callable=node_callable, + node_kwargs=node_kwargs, + task_id=task_id, + ) + if self.behaviors[run_id][node_name] in ( + CachingBehavior.RECOMPUTE, + CachingBehavior.IGNORE, + ): + cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id) + + # nodes collected in `._data_loaders` return tuples of (result, metadata) + # where metadata often includes a timestamp. To ensure we provide a consistent + # `data_version` / hash, we only hash the result part of the materializer return + # value and discard the metadata. + if node_name in self._data_loaders[run_id] and isinstance(result, tuple): + result = result[0] + + data_version = self._version_data(node_name=node_name, run_id=run_id, result=result) + + # nodes collected in `._data_savers` return a dictionary of metadata + # this metadata often includes a timestamp, leading to an unstable hash. 
+ # we do not version nor store the metadata. This node is executed for its + # external effect of saving a file + if node_name in self._data_savers[run_id]: + data_version = f"{node_name}__metadata" + + self._set_memory_metadata( + run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version + ) + self._set_stored_metadata( + run_id=run_id, + node_name=node_name, + task_id=task_id, + cache_key=cache_key, + data_version=data_version, + ) + + return result + + # cache_key is set in `pre_node_execute` + cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id) + # retrieve data version from memory or metadata_store + data_version = self.get_data_version( + run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key + ) + + need_to_compute_node = False + if data_version is SENTINEL: + # must execute: data_version not found in memory or in metadata_store + need_to_compute_node = True + elif data_version == fingerprinting.UNHASHABLE: + # must execute: the retrieved data_version is UNHASHABLE, therefore it isn't stored. + need_to_compute_node = True + elif self.result_store.exists(data_version) is False: + # must execute: data_version retrieved, but result store can't find result + need_to_compute_node = True + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="result_store", + event_type=CachingEventType.MISSING_RESULT, + value=data_version, + ) + else: + # try to retrieve: data_version retrieve, result store found result + try: + # successful retrieval: retrieve the result; potentially load using the DataLoader if e.g.,``@cache(format="json")`` + result = self.result_store.get(data_version=data_version) + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="result_store", + event_type=CachingEventType.GET_RESULT, + msg="hit", + value=data_version, + ) + # set the data_version previously retrieved (could be from memory or store) + self._set_memory_metadata( + run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version + ) + except ResultRetrievalError: + # failed retrieval: despite finding the result, probably failed loading data using DataLoader if e.g.,``@cache(format="json")`` + self.metadata_store.delete(cache_key=cache_key) + self.result_store.delete(data_version) + need_to_compute_node = True + + if need_to_compute_node is True: + result = self._execute_node( + run_id=run_id, + node_name=node_name, + node_callable=node_callable, + node_kwargs=node_kwargs, + task_id=task_id, + ) + + # nodes collected in `._data_loaders` return tuples of (result, metadata) + # where metadata often includes a timestamp. To ensure we provide a consistent + # `data_version` / hash, we only hash the result part of the materializer return + # value and discard the metadata. + if node_name in self._data_loaders[run_id] and isinstance(result, tuple): + result = result[0] + + data_version = self._version_data(node_name=node_name, run_id=run_id, result=result) + + # nodes collected in `._data_savers` return a dictionary of metadata + # this metadata often includes a timestamp, leading to an unstable hash. + # we do not version nor store the metadata. 
This node is executed for its + # external effect of saving a file + if node_name in self._data_savers[run_id]: + data_version = f"{node_name}__metadata" + + self._set_memory_metadata( + run_id=run_id, node_name=node_name, task_id=task_id, data_version=data_version + ) + self._set_stored_metadata( + run_id=run_id, + node_name=node_name, + task_id=task_id, + cache_key=cache_key, + data_version=data_version, + ) + + return result + + def post_node_execute( + self, + *, + run_id: str, + node_: hamilton.node.Node, + result: Optional[str], + success: bool = True, + error: Optional[Exception] = None, + task_id: Optional[str] = None, + **future_kwargs, + ): + """Get the cache_key and data_version stored in memory (respectively from + pre_node_execute and do_node_execute) and store the result in result_store + if it doesn't exist. + """ + node_name = node_.name + + if success is False: + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="adapter", + event_type=CachingEventType.FAILED_EXECUTION, + msg=f"{error}", + ) + return + + if self.behaviors[run_id][node_name] in ( + CachingBehavior.DEFAULT, + CachingBehavior.RECOMPUTE, + CachingBehavior.IGNORE, + ): + cache_key = self.get_cache_key(run_id=run_id, node_name=node_name, task_id=task_id) + data_version = self.get_data_version( + run_id=run_id, node_name=node_name, task_id=task_id, cache_key=cache_key + ) + assert data_version is not SENTINEL + + # TODO clean up this logic + # check if a materialized file exist before writing results + # when using `@cache(format="json")` + cache_format = ( + self._fn_graphs[run_id] + .nodes[node_name] + .tags.get(cache_decorator.FORMAT_KEY, SENTINEL) + ) + if cache_format is not SENTINEL: + saver_cls, loader_cls = search_data_adapter_registry( + name=cache_format, type_=type(result) + ) # type: ignore + materialized_path = self.result_store._materialized_path(data_version, saver_cls) + materialized_path_missing = not materialized_path.exists() + else: + saver_cls, loader_cls = None, None + materialized_path_missing = False + + result_missing = not self.result_store.exists(data_version) + if result_missing or materialized_path_missing: + self.result_store.set( + data_version=data_version, + result=result, + saver_cls=saver_cls, + loader_cls=loader_cls, + ) + self._log_event( + run_id=run_id, + node_name=node_name, + task_id=task_id, + actor="result_store", + event_type=CachingEventType.SET_RESULT, + value=data_version, + ) diff --git a/hamilton/caching/cache_key.py b/hamilton/caching/cache_key.py new file mode 100644 index 000000000..dbd14d744 --- /dev/null +++ b/hamilton/caching/cache_key.py @@ -0,0 +1,53 @@ +import base64 +import zlib +from typing import Dict, Mapping + + +def _compress_string(string: str) -> str: + return base64.b64encode(zlib.compress(string.encode(), level=3)).decode() + + +def _decompress_string(string: str) -> str: + return zlib.decompress(base64.b64decode(string.encode())).decode() + + +def _encode_str_dict(d: Mapping) -> str: + interleaved_tuple = tuple(item for pair in sorted(d.items()) for item in pair) + return ",".join(interleaved_tuple) + + +def _decode_str_dict(s: str) -> Mapping: + interleaved_tuple = tuple(s.split(",")) + d = {} + for i in range(0, len(interleaved_tuple), 2): + d[interleaved_tuple[i]] = interleaved_tuple[i + 1] + return d + + +def decode_key(cache_key: str) -> dict: + node_name, _, code_and_data_string = cache_key.partition("-") + code_version, _, dep_encoded = code_and_data_string.partition("-") + data_stringified = 
_decompress_string(dep_encoded) + + if data_stringified == "": + dependencies_data_versions = {} + else: + dependencies_data_versions = _decode_str_dict(data_stringified) + return dict( + node_name=node_name, + code_version=code_version, + dependencies_data_versions=dependencies_data_versions, + ) + + +def create_cache_key( + node_name: str, code_version: str, dependencies_data_versions: Dict[str, str] +) -> str: + if len(dependencies_data_versions.keys()) > 0: + dependencies_stringified = _encode_str_dict(dependencies_data_versions) + else: + dependencies_stringified = "" + + safe_node_name = "".join(c for c in node_name if c.isalnum() or c in ("_",)).rstrip() + + return f"{safe_node_name}-{code_version}-{_compress_string(dependencies_stringified)}" diff --git a/hamilton/caching/fingerprinting.py b/hamilton/caching/fingerprinting.py new file mode 100644 index 000000000..1e1ee0ae0 --- /dev/null +++ b/hamilton/caching/fingerprinting.py @@ -0,0 +1,267 @@ +""" +This module contains hashing functions for Python objects. It uses +functools.singledispatch to allow specialized implementations based on type. +Singledispatch automatically applies the most specific implementation + +This module houses implementations for the Python standard library. Supporting +all types is considerable endeavor, so we'll add support as types are requested +by users. + +Otherwise, 3rd party types can be supported via the `h_databackends` module. +This registers abstract types that can be checked without having to import the +3rd party library. For instance, there are implementations for pandas.DataFrame +and polars.DataFrame despite these libraries not being imported here. + +IMPORTANT all container types that make a recursive call to `hash_value` or a specific +implementation should pass the `depth` parameter to prevent `RecursionError`. +""" + +import base64 +import datetime +import functools +import hashlib +import logging +import sys +from collections.abc import Mapping, Sequence, Set +from typing import Dict + +from hamilton.experimental import h_databackends + +# NoneType is introduced in Python 3.10 +try: + from types import NoneType +except ImportError: + NoneType = type(None) + + +logger = logging.getLogger("hamilton.caching") + + +MAX_DEPTH = 6 +UNHASHABLE = "" +NONE_HASH = "" + + +def set_max_depth(depth: int) -> None: + """Set the maximum recursion depth for fingerprinting non-supported types. + + :param depth: The maximum depth for fingerprinting. + """ + global MAX_DEPTH + MAX_DEPTH = depth + + +def _compact_hash(digest: bytes) -> str: + """Compact the hash to a string that's safe to pass around. + + NOTE this is particularly relevant for the Hamilton UI and + passing hashes/fingerprints through web services. + """ + return base64.urlsafe_b64encode(digest).decode() + + +@functools.singledispatch +def hash_value(obj, *args, depth=0, **kwargs) -> str: + """Fingerprinting strategy that computes a hash of the + full Python object. + + The default case hashes the `__dict__` attribute of the + object (recursive). + """ + if depth > MAX_DEPTH: + return UNHASHABLE + + if hasattr(obj, "__dict__"): + return hash_value(obj.__dict__, depth=depth + 1) + + # check if the object comes from a module part of the standard library + # if it's the case, hash it's __repr__(), which is a string representation of the object + # __repr__() from the standard library should be well-formed and offer a reliable basis + # for fingerprinting. 
+ # for example, this will catch: pathlib.Path, enum.Enum, argparse.Namespace + elif getattr(obj, "__module__", False): + if obj.__module__.partition(".")[0] in sys.builtin_module_names: + return hash_repr(obj, depth=depth) + + # cover the datetime module, which doesn't have a __module__ attribute + elif type(obj) in vars(datetime).values(): + return hash_repr(obj, depth=depth) + + return UNHASHABLE + + +@hash_value.register(NoneType) +def hash_none(obj, *args, **kwargs) -> str: + """Hash for None is + + Primitive type returns a hash and doesn't have to handle depth. + """ + return NONE_HASH + + +def hash_repr(obj, *args, **kwargs) -> str: + """Use the built-in repr() to get a string representation of the object + and hash it. + + While `.__repr__()` might not be implemented for all classes, the function + `repr()` will handle it, along with exceptions, to always return a value. + + Primitive type returns a hash and doesn't have to handle depth. + """ + return hash_primitive(repr(obj)) + + +# we need to use explicit multiple registration because older Python +# versions don't support type annotations with Union types +@hash_value.register(str) +@hash_value.register(int) +@hash_value.register(float) +@hash_value.register(bool) +def hash_primitive(obj, *args, **kwargs) -> str: + """Convert the primitive to a string and hash it + + Primitive type returns a hash and doesn't have to handle depth. + """ + hash_object = hashlib.md5(str(obj).encode()) + return _compact_hash(hash_object.digest()) + + +@hash_value.register(bytes) +def hash_bytes(obj, *args, **kwargs) -> str: + """Convert the primitive to a string and hash it + + Primitive type returns a hash and doesn't have to handle depth. + """ + hash_object = hashlib.md5(obj) + return _compact_hash(hash_object.digest()) + + +@hash_value.register(Sequence) +def hash_sequence(obj, *args, depth: int = 0, **kwargs) -> str: + """Hash each object of the sequence. + + Orders matters for the hash since orders matters in a sequence. + """ + hash_object = hashlib.sha224() + for elem in obj: + hash_object.update(hash_value(elem, depth=depth + 1).encode()) + + return _compact_hash(hash_object.digest()) + + +def hash_unordered_mapping(obj, *args, depth: int = 0, **kwargs) -> str: + """ + + When hashing an unordered mapping, the two following dict have the same hash. + + .. code-block:: python + + foo = {"key": 3, "key2": 13} + bar = {"key2": 13, "key": 3} + + hash_mapping(foo) == hash_mapping(bar) + """ + + hashed_mapping: Dict[str, str] = {} + for key, value in obj.items(): + hashed_mapping[hash_value(key, depth=depth + 1)] = hash_value(value, depth=depth + 1) + + hash_object = hashlib.sha224() + for key, value in sorted(hashed_mapping.items()): + hash_object.update(key.encode()) + hash_object.update(value.encode()) + + return _compact_hash(hash_object.digest()) + + +@hash_value.register(Mapping) +def hash_mapping(obj, *, ignore_order: bool = True, depth: int = 0, **kwargs) -> str: + """Hash each key then its value. + + The mapping is always sorted first because order shouldn't matter + in a mapping. + + NOTE Since Python 3.7, dictionary store insertion order. However, this + function assumes that they key order doesn't matter to uniquely identify + the dictionary. + + .. 
code-block:: python + + foo = {"key": 3, "key2": 13} + bar = {"key2": 13, "key": 3} + + hash_mapping(foo) == hash_mapping(bar) + + """ + if ignore_order: + # use the same depth because we're simply dispatching to another implementation + return hash_unordered_mapping(obj, depth=depth) + + hash_object = hashlib.sha224() + for key, value in obj.items(): + hash_object.update(hash_value(key, depth=depth + 1).encode()) + hash_object.update(hash_value(value, depth=depth + 1).encode()) + + return _compact_hash(hash_object.digest()) + + +@hash_value.register(Set) +def hash_set(obj, *args, depth: int = 0, **kwargs) -> str: + """Hash each element of the set, then sort hashes, and + create a hash of hashes. + + For the same objects in the set, the hashes will be the + same. + """ + hashes = [hash_value(elem, depth=depth + 1) for elem in obj] + sorted_hashes = sorted(hashes) + + hash_object = hashlib.sha224() + for hash in sorted_hashes: + hash_object.update(hash.encode()) + + return _compact_hash(hash_object.digest()) + + +@hash_value.register(h_databackends.AbstractPandasDataFrame) +@hash_value.register(h_databackends.AbstractPandasColumn) +def hash_pandas_obj(obj, *args, depth: int = 0, **kwargs) -> str: + """Convert a pandas dataframe, series, or index to + a dictionary of {index: row_hash} then hash it. + + Given the hashing for mappings, the physical ordering or rows doesn't matter. + For example, if the index is a date, the hash will represent the {date: row_hash}, + and won't preserve how dates were ordered in the DataFrame. + """ + from pandas.util import hash_pandas_object + + hash_per_row = hash_pandas_object(obj) + return hash_mapping(hash_per_row.to_dict(), ignore_order=False, depth=depth + 1) + + +@hash_value.register(h_databackends.AbstractPolarsDataFrame) +def hash_polars_dataframe(obj, *args, depth: int = 0, **kwargs) -> str: + """Convert a polars dataframe, series, or index to + a list of hashes then hash it. + """ + hash_per_row = obj.hash_rows() + return hash_sequence(hash_per_row.to_list(), depth=depth + 1) + + +@hash_value.register(h_databackends.AbstractPolarsColumn) +def hash_polars_column(obj, *args, depth: int = 0, **kwargs) -> str: + """Promote the single Series to a dataframe and hash it""" + # use the same depth because we're simply dispatching to another implementation + return hash_polars_dataframe(obj.to_frame(), depth=depth) + + +@hash_value.register(h_databackends.AbstractNumpyArray) +def hash_numpy_array(obj, *args, depth: int = 0, **kwargs) -> str: + """Get the bytes representation of the array raw data and hash it. + + Might not be ideal because different higher-level numpy objects could have + the same underlying array representation (e.g., masked arrays). + Unsure, but it's an area to investigate. 
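Because ``hash_value`` is a ``functools.singledispatch`` function, support for additional types can be added by registering a specialized implementation, which is the "register a versioning function" option mentioned in the adapter's unhashable-value warning. A sketch with a hypothetical ``MyModel`` type:

.. code-block:: python

    from hamilton.caching import fingerprinting

    class MyModel:  # hypothetical user type
        def __init__(self, weights: list):
            self.weights = weights

    @fingerprinting.hash_value.register(MyModel)
    def hash_my_model(obj, *args, depth=0, **kwargs) -> str:
        # delegate to the existing Sequence implementation, hashing only the stable part
        return fingerprinting.hash_sequence(obj.weights, depth=depth)

    # dictionaries hash the same regardless of key order
    assert fingerprinting.hash_value({"a": 1, "b": 2}) == fingerprinting.hash_value({"b": 2, "a": 1})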
+ """ + # use the same depth because we're simply dispatching to another implementation + return hash_bytes(obj.tobytes(), depth=depth) diff --git a/hamilton/caching/stores/__init__.py b/hamilton/caching/stores/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/hamilton/caching/stores/base.py b/hamilton/caching/stores/base.py new file mode 100644 index 000000000..901edb412 --- /dev/null +++ b/hamilton/caching/stores/base.py @@ -0,0 +1,220 @@ +import abc +import pickle +from datetime import datetime, timedelta, timezone +from typing import Any, Optional, Sequence, Tuple, Type + +from hamilton.htypes import custom_subclass_check +from hamilton.io.data_adapters import DataLoader, DataSaver +from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY + + +class ResultRetrievalError(Exception): + """Raised by the SmartCacheAdapter when ResultStore.get() fails.""" + + +# TODO Currently, this check is done when data needs to be saved. +# Ideally, it would be done earlier in the caching lifecycle. +def search_data_adapter_registry( + name: str, type_: type +) -> Tuple[Type[DataSaver], Type[DataLoader]]: + """Find pair of DataSaver and DataLoader registered with `name` and supporting `type_`""" + if name not in SAVER_REGISTRY or name not in LOADER_REGISTRY: + raise KeyError( + f"{name} isn't associated to both a DataLoader and a DataSaver. " + "Default saver/loader pairs include `json`, `file`, `pickle`, `parquet`, `csv`, " + "`feather`, `orc`, `excel`. More pairs may be available through plugins." + ) + + try: + saver_cls = next( + saver_cls + for saver_cls in SAVER_REGISTRY[name] + if any( + custom_subclass_check(type_, applicable_type) + for applicable_type in saver_cls.applicable_types() + ) + ) + except StopIteration as e: + raise KeyError(f"{name} doesn't have any DataSaver supporting type {type_}") from e + + try: + loader_cls = next( + loader_cls + for loader_cls in LOADER_REGISTRY[name] + if any( + custom_subclass_check(type_, applicable_type) + for applicable_type in loader_cls.applicable_types() + ) + ) + except StopIteration as e: + raise KeyError(f"{name} doesn't have any DataLoader supporting type {type_}") from e + + return saver_cls, loader_cls + + +class ResultStore(abc.ABC): + @abc.abstractmethod + def set(self, data_version: str, result: Any, **kwargs) -> None: + """Store ``result`` keyed by ``data_version``.""" + + @abc.abstractmethod + def get(self, data_version: str, **kwargs) -> Optional[Any]: + """Try to retrieve ``result`` keyed by ``data_version``. + If retrieval misses, return ``None``. + """ + + @abc.abstractmethod + def delete(self, data_version: str) -> None: + """Delete ``result`` keyed by ``data_version``.""" + + @abc.abstractmethod + def delete_all(self) -> None: + """Delete all stored results.""" + + @abc.abstractmethod + def exists(self, data_version: str) -> bool: + """boolean check if a ``result`` is found for ``data_version`` + If True, ``.get()`` should successfully retrieve the ``result``. + """ + + +class MetadataStore(abc.ABC): + @abc.abstractmethod + def __len__(self) -> int: + """Return the number of cache_keys in the metadata store""" + + @abc.abstractmethod + def initialize(self, run_id: str) -> None: + """Setup the metadata store and log the start of the run""" + + @abc.abstractmethod + def set(self, cache_key: str, data_version: str, **kwargs) -> Optional[Any]: + """Store the mapping ``cache_key -> data_version``. + Can include other metadata (e.g., node name, run id, code version) depending + on the implementation. 
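A sketch of plugging the concrete store implementations added by this patch into the driver, assuming ``Builder.with_cache()`` forwards ``metadata_store`` and ``result_store`` to the adapter constructor; the shared path is a placeholder.

.. code-block:: python

    from hamilton import driver
    from hamilton.caching.stores.file import FileResultStore
    from hamilton.caching.stores.sqlite import SQLiteMetadataStore
    import my_dataflow  # hypothetical user module

    dr = (
        driver.Builder()
        .with_modules(my_dataflow)
        .with_cache(
            metadata_store=SQLiteMetadataStore(path="/shared/hamilton_cache"),
            result_store=FileResultStore(path="/shared/hamilton_cache"),
        )
        .build()
    )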
+ """ + + @abc.abstractmethod + def get(self, cache_key: str) -> Optional[str]: + """Try to retrieve ``data_version`` keyed by ``cache_key``. + If retrieval misses return ``None``. + """ + + @abc.abstractmethod + def delete(self, cache_key: str) -> None: + """Delete ``data_version`` keyed by ``cache_key``.""" + + @abc.abstractmethod + def delete_all(self) -> None: + """Delete all stored metadata.""" + + @abc.abstractmethod + def exists(self, cache_key: str) -> bool: + """boolean check if a ``data_version`` is found for ``cache_key`` + If True, ``.get()`` should successfully retrieve the ``data_version``. + """ + + @abc.abstractmethod + def get_run_ids(self) -> Sequence[str]: + """Return a list of run ids, sorted from oldest to newest start time. + A ``run_id`` is registered when the metadata_store ``.initialize()`` is called. + + NOTE because of race conditions, the order could theoretically differ from the + order stored on the SmartCacheAdapter `._run_ids` attribute. + """ + + @abc.abstractmethod + def get_run(self, run_id: str) -> Any: + """Return all the metadata associated with a run. + The metadata content may differ across MetadataStore implementations + """ + + @property + def size(self) -> int: + """Number of unique entries (i.e., cache_keys) in the metadata_store""" + return self.__len__() + + @property + def last_run_id(self) -> str: + """Return""" + return self.get_run_ids()[-1] + + def get_last_run(self) -> Any: + """Return the metadata from the last started run.""" + return self.get_run(self.last_run_id) + + +# TODO refactor the association between StoredResult, MetadataStore, and ResultStore +# to load data using the `DataLoader` class and kwargs instead of pickling the instantiated +# DataLoader object. This would be safer across Hamilton versions. +class StoredResult: + def __init__( + self, + value: Any, + expires_at=None, + saver=None, + loader=None, + ): + self.value = value + self.expires_at = expires_at + self.saver = saver + self.loader = loader + + @classmethod + def new( + cls, + value: Any, + expires_in: Optional[timedelta] = None, + saver: Optional[DataSaver] = None, + loader: Optional[DataLoader] = None, + ) -> "StoredResult": + if expires_in is not None and not isinstance(expires_in, timedelta): + expires_in = timedelta(seconds=expires_in) + + # != operator on boolean is XOR + if bool(saver is not None) != bool(loader is not None): + raise ValueError( + "Must pass both `saver` and `loader` or neither. 
Currently received: " + f"`saver`: `{saver}`; `loader`: `{loader}`" + ) + + return cls( + value=value, + expires_at=(datetime.now(tz=timezone.utc) + expires_in) if expires_in else None, + saver=saver, + loader=loader, + ) + + @property + def expired(self) -> bool: + return self.expires_at is not None and datetime.now(tz=timezone.utc) >= self.expires_at + + @property + def expires_in(self) -> int: + if self.expires_at: + return int(self.expires_at.timestamp() - datetime.now(tz=timezone.utc).timestamp()) + + return -1 + + def save(self) -> bytes: + """Receives pickleable data or DataLoader to use to load the real data""" + if self.saver is not None: + self.saver.save_data(data=self.value) + to_pickle = self.loader + else: + to_pickle = self.value + + return pickle.dumps(to_pickle) + + @classmethod + def load(cls, raw: bytes) -> "StoredResult": + """Reads the raw bytes from disk and sets `StoredResult.data`""" + loaded = pickle.loads(raw) + if isinstance(loaded, DataLoader): + loader = loaded + result, metadata = loader.load_data(None) + else: + loader = None + result = loaded + + return StoredResult.new(value=result) diff --git a/hamilton/caching/stores/file.py b/hamilton/caching/stores/file.py new file mode 100644 index 000000000..d483ca292 --- /dev/null +++ b/hamilton/caching/stores/file.py @@ -0,0 +1,89 @@ +import shutil +from pathlib import Path +from typing import Any, Optional + +from hamilton.caching.stores.base import ResultStore, StoredResult +from hamilton.io.data_adapters import DataLoader, DataSaver + + +class FileResultStore(ResultStore): + def __init__(self, path: str, create_dir: bool = True) -> None: + self.path = Path(path) + self.create_dir = create_dir + + if self.create_dir: + self.path.mkdir(exist_ok=True, parents=True) + + @staticmethod + def _write_result(file_path: Path, stored_result: StoredResult) -> None: + file_path.write_bytes(stored_result.save()) + + @staticmethod + def _load_result_from_path(path: Path) -> Optional[StoredResult]: + try: + data = path.read_bytes() + return StoredResult.load(data) + except FileNotFoundError: + return None + + def _path_from_data_version(self, data_version: str) -> Path: + return self.path.joinpath(data_version) + + def _materialized_path(self, data_version: str, saver_cls: DataSaver) -> Path: + # TODO allow a more flexible mechanism to specify file path extension + return self._path_from_data_version(data_version).with_suffix(f".{saver_cls.name()}") + + def exists(self, data_version: str) -> bool: + result_path = self._path_from_data_version(data_version) + return result_path.exists() + + def set( + self, + data_version: str, + result: Any, + saver_cls: Optional[DataSaver] = None, + loader_cls: Optional[DataLoader] = None, + ) -> None: + # != operator on boolean is XOR + if bool(saver_cls is not None) != bool(loader_cls is not None): + raise ValueError( + "Must pass both `saver` and `loader` or neither. 
Currently received: " + f"`saver`: `{saver_cls}`; `loader`: `{loader_cls}`" + ) + + if saver_cls is not None: + # materialized_path + materialized_path = self._materialized_path(data_version, saver_cls) + saver = saver_cls(path=str(materialized_path.absolute())) + loader = loader_cls(path=str(materialized_path.absolute())) + else: + saver = None + loader = None + + self.path.mkdir(exist_ok=True) + result_path = self._path_from_data_version(data_version) + stored_result = StoredResult.new(value=result, saver=saver, loader=loader) + self._write_result(result_path, stored_result) + + def get(self, data_version: str) -> Optional[Any]: + result_path = self._path_from_data_version(data_version) + stored_result = self._load_result_from_path(result_path) + + if stored_result is None: + return None + + return stored_result.value + + def delete(self, data_version: str) -> None: + result_path = self._path_from_data_version(data_version) + result_path.unlink(missing_ok=True) + + def delete_all(self) -> None: + shutil.rmtree(self.path) + self.path.mkdir(exist_ok=True) + + def delete_expired(self) -> None: + for file_path in self.path.iterdir(): + stored_result = self._load_result_from_path(file_path) + if stored_result and stored_result.expired: + file_path.unlink(missing_ok=True) diff --git a/hamilton/caching/stores/sqlite.py b/hamilton/caching/stores/sqlite.py new file mode 100644 index 000000000..c234024bd --- /dev/null +++ b/hamilton/caching/stores/sqlite.py @@ -0,0 +1,204 @@ +import pathlib +import sqlite3 +import threading +from typing import List, Optional + +from hamilton.caching.stores.base import MetadataStore + + +class SQLiteMetadataStore(MetadataStore): + def __init__( + self, + path: str, + connection_kwargs: Optional[dict] = None, + ) -> None: + self._directory = pathlib.Path(path).resolve() + self._directory.mkdir(parents=True, exist_ok=True) + self._path = self._directory.joinpath("metadata_store").with_suffix(".db") + self.connection_kwargs: dict = connection_kwargs if connection_kwargs else {} + + self._thread_local = threading.local() + + def _get_connection(self): + if not hasattr(self._thread_local, "connection"): + self._thread_local.connection = sqlite3.connect( + str(self._path), check_same_thread=False, **self.connection_kwargs + ) + return self._thread_local.connection + + def _close_connection(self): + if hasattr(self._thread_local, "connection"): + self._thread_local.connection.close() + del self._thread_local.connection + + @property + def connection(self): + return self._get_connection() + + def __del__(self): + """Close the SQLite connection when the object is deleted""" + self._close_connection() + + def _create_tables_if_not_exists(self): + """Create the tables necessary for the cache: + + run_ids: queue of run_ids, ordered by start time. + history: queue of executed node; allows to query "latest" execution of a node + cache_metadata: information to determine if a node needs to be computed or not + + In the table ``cache_metadata``, the ``cache_key`` is unique whereas + ``history`` allows duplicate. 
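+
+        As a hedged illustration of that uniqueness rule (the ``"key-1"`` /
+        ``"foo"`` values and the ``.hamilton_cache`` path below are made up for
+        the example):
+
+        .. code-block:: python
+
+            store = SQLiteMetadataStore(path=".hamilton_cache")
+            store.initialize(run_id="run-1")
+            # setting the same cache_key twice appends two rows to `history` ...
+            store.set(
+                cache_key="key-1",
+                node_name="foo",
+                code_version="FOO-1",
+                data_version="foo-a",
+                run_id="run-1",
+            )
+            store.set(
+                cache_key="key-1",
+                node_name="foo",
+                code_version="FOO-1",
+                data_version="foo-a",
+                run_id="run-1",
+            )
+            # ... but `cache_metadata` keeps a single entry per cache_key
+            assert store.size == 1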
+ """ + cur = self.connection.cursor() + + cur.execute( + """\ + CREATE TABLE IF NOT EXISTS run_ids ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + run_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + cur.execute( + """\ + CREATE TABLE IF NOT EXISTS history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + cache_key TEXT, + run_id TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + FOREIGN KEY (cache_key) REFERENCES cache_metadata(cache_key) + ); + """ + ) + cur.execute( + """\ + CREATE TABLE IF NOT EXISTS cache_metadata ( + cache_key TEXT PRIMARY KEY, + node_name TEXT NOT NULL, + code_version TEXT NOT NULL, + data_version TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + FOREIGN KEY (cache_key) REFERENCES history(cache_key) + ); + """ + ) + self.connection.commit() + + def initialize(self, run_id) -> None: + """Call initialize when starting a run. This will create database tables + if necessary. + """ + self._create_tables_if_not_exists() + cur = self.connection.cursor() + cur.execute("INSERT INTO run_ids (run_id) VALUES (?)", (run_id,)) + self.connection.commit() + + def __len__(self): + """Number of entries in cache_metadata""" + cur = self.connection.cursor() + cur.execute("SELECT COUNT(*) FROM cache_metadata") + return cur.fetchone()[0] + + def set( + self, + *, + cache_key: str, + node_name: str, + code_version: str, + data_version: str, + run_id: str, + ) -> None: + cur = self.connection.cursor() + + cur.execute("INSERT INTO history (cache_key, run_id) VALUES (?, ?)", (cache_key, run_id)) + cur.execute( + """\ + INSERT OR IGNORE INTO cache_metadata ( + cache_key, node_name, code_version, data_version + ) VALUES (?, ?, ?, ?) + """, + (cache_key, node_name, code_version, data_version), + ) + + self.connection.commit() + + def get(self, cache_key: str) -> Optional[str]: + cur = self.connection.cursor() + cur.execute( + """\ + SELECT data_version + FROM cache_metadata + WHERE cache_key = ? + """, + (cache_key,), + ) + result = cur.fetchone() + + if result is None: + data_version = None + else: + data_version = result[0] + + return data_version + + def delete(self, cache_key: str) -> None: + """Delete metadata associated with ``cache_key``.""" + cur = self.connection.cursor() + cur.execute("DELETE FROM cache_metadata WHERE cache_key = ?", (cache_key,)) + self.connection.commit() + + def delete_all(self): + """Delete all existing tables from the database""" + cur = self.connection.cursor() + + for table_name in ["run_ids", "history", "cache_metadata"]: + cur.execute(f"DROP TABLE IF EXISTS {table_name};") + + self.connection.commit() + + def exists(self, cache_key: str) -> bool: + """boolean check if a ``data_version`` is found for ``cache_key`` + If True, ``.get()`` should successfully retrieve the ``data_version``. + """ + cur = self.connection.cursor() + cur.execute("SELECT cache_key FROM cache_metadata WHERE cache_key = ?", (cache_key,)) + result = cur.fetchone() + + return result is not None + + def get_run_ids(self) -> List[str]: + cur = self.connection.cursor() + cur.execute("SELECT run_id FROM history ORDER BY id") + result = cur.fetchall() + + if result is None: + raise IndexError("No `run_id` found. 
Table `history` is empty.") + + return result[0] + + def get_run(self, run_id: str) -> List[dict]: + """Return all the metadata associated with a run.""" + cur = self.connection.cursor() + cur.execute( + """\ + SELECT + cache_metadata.node_name, + cache_metadata.code_version, + cache_metadata.data_version + FROM (SELECT * FROM history WHERE history.run_id = ?) AS run_history + JOIN cache_metadata ON run_history.cache_key = cache_metadata.cache_key + """, + (run_id,), + ) + results = cur.fetchall() + + if results is None: + raise IndexError(f"`run_id` not found in table `history`: {run_id}") + + return [ + dict(node_name=node_name, code_version=code_version, data_version=data_version) + for node_name, code_version, data_version in results + ] diff --git a/hamilton/caching/stores/utils.py b/hamilton/caching/stores/utils.py new file mode 100644 index 000000000..1a69b41eb --- /dev/null +++ b/hamilton/caching/stores/utils.py @@ -0,0 +1,21 @@ +import pathlib + + +def get_directory_size(directory: str) -> float: + total_size = 0 + for p in pathlib.Path(directory).rglob("*"): + if p.is_file(): + total_size += p.stat().st_size + + return total_size + + +def readable_bytes_size(n_bytes: float) -> str: + labels = ["B", "KB", "MB", "GB", "TB"] + exponent = 0 + + while n_bytes > 1024.0: + n_bytes /= 1024.0 + exponent += 1 + + return f"{n_bytes:.2f} {labels[exponent]}" diff --git a/hamilton/driver.py b/hamilton/driver.py index 287e29258..2e2c6beb7 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -5,20 +5,34 @@ import json import logging import operator +import pathlib import sys import time # required if we want to run this code stand alone. import typing import uuid -from collections.abc import Sequence # typing.Sequence is deprecated in >=3.9 from datetime import datetime from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Union, +) import pandas as pd from hamilton import common, graph_types, htypes +from hamilton.caching.adapter import HamiltonCacheAdapter +from hamilton.caching.stores.base import MetadataStore, ResultStore from hamilton.dev_utils import deprecation from hamilton.execution import executors, graph_functions, grouping, state from hamilton.graph_types import HamiltonNode @@ -1818,6 +1832,18 @@ def validate_materialization( all_nodes = nodes | user_nodes self.graph_executor.validate(list(all_nodes)) + @property + def cache(self) -> HamiltonCacheAdapter: + """Directly access the cache adapter""" + if self.adapter: + for adapter in self.adapter.adapters: + if isinstance(adapter, HamiltonCacheAdapter): + return adapter + else: + raise KeyError( + "Cache not yet set. Add a cache by using ``Builder().with_cache()`` when building the ``Driver``." + ) + class Builder: def __init__(self): @@ -1905,6 +1931,11 @@ def with_adapters(self, *adapters: lifecycle_base.LifecycleAdapter) -> "Builder" :param adapter: Adapter to use. :return: self """ + if any(isinstance(adapter, HamiltonCacheAdapter) for adapter in adapters): + self._require_field_unset( + "cache", "Cannot use `.with_cache()` or with `.with_adapters(SmartCacheAdapter())`." 
+ ) + self.adapters.extend(adapters) return self @@ -1932,6 +1963,89 @@ def with_materializers( self.materializers.extend(materializers) return self + def with_cache( + self, + path: Union[str, pathlib.Path] = ".hamilton_cache", + metadata_store: Optional[MetadataStore] = None, + result_store: Optional[ResultStore] = None, + default: Optional[Union[Literal[True], Sequence[str]]] = None, + recompute: Optional[Union[Literal[True], Sequence[str]]] = None, + ignore: Optional[Union[Literal[True], Sequence[str]]] = None, + disable: Optional[Union[Literal[True], Sequence[str]]] = None, + default_behavior: Literal["default", "recompute", "disable", "ignore"] = "default", + default_loader_behavior: Literal["default", "recompute", "disable", "ignore"] = "default", + default_saver_behavior: Literal["default", "recompute", "disable", "ignore"] = "default", + log_to_file: bool = False, + ) -> "Builder": + """Add the caching adapter to the `Driver` + + :param path: path where the cache metadata and results will be stored + :param metadata_store: BaseStore handling metadata for the cache adapter + :param result_store: BaseStore caching dataflow execution results + :param default: Set caching behavior to DEFAULT for specified node names. If True, apply to all nodes. + :param recompute: Set caching behavior to RECOMPUTE for specified node names. If True, apply to all nodes. + :param ignore: Set caching behavior to IGNORE for specified node names. If True, apply to all nodes. + :param disable: Set caching behavior to DISABLE for specified node names. If True, apply to all nodes. + :param default_behavior: Set the default caching behavior. + :param default_loader_behavior: Set the default caching behavior `DataLoader` nodes. + :param default_saver_behavior: Set the default caching behavior `DataSaver` nodes. + :log_to_file: If True, the cache adapter logs will be stored in JSONL format under the metadata_store directory + :return: self + + + Learn more on the :doc:`/concepts/caching` Concepts page. + + .. code-block:: python + + from hamilton import driver + import my_dataflow + + dr = ( + driver.Builder() + .with_module(my_dataflow) + .with_cache() + .build() + ) + + # execute twice + dr.execute([...]) + dr.execute([...]) + + # view cache logs + dr.cache.logs() + + """ + self._require_field_unset( + "cache", "Cannot use `.with_cache()` or with `.with_adapters(SmartCacheAdapter())`." + ) + adapter = HamiltonCacheAdapter( + path=path, + metadata_store=metadata_store, + result_store=result_store, + default=default, + recompute=recompute, + ignore=ignore, + disable=disable, + default_behavior=default_behavior, + default_loader_behavior=default_loader_behavior, + default_saver_behavior=default_saver_behavior, + log_to_file=log_to_file, + ) + self.adapters.append(adapter) + return self + + @property + def cache(self) -> Optional[HamiltonCacheAdapter]: + """Attribute to check if a cache was set, either via `.with_cache()` or + `.with_adapters(SmartCacheAdapter())` + + Required for the check `._require_field_unset()` + """ + if self.adapters: + for adapter in self.adapters: + if isinstance(adapter, HamiltonCacheAdapter): + return adapter + def with_execution_manager(self, execution_manager: executors.ExecutionManager) -> "Builder": """Sets the execution manager to use. 
Note that this cannot be used if local_executor or remote_executor are also set diff --git a/hamilton/experimental/h_cache.py b/hamilton/experimental/h_cache.py index 5fe7a31e9..b5a970db1 100644 --- a/hamilton/experimental/h_cache.py +++ b/hamilton/experimental/h_cache.py @@ -12,6 +12,16 @@ logger = logging.getLogger(__name__) + +logger.warning( + "The module `hamilton.experimental.h_cache` and the class `CachingGraphAdapter `" + "are deprecated and will be removed in Hamilton 2.0. " + "Consider enabling the core caching feature via `Builder.with_cache()`. " + "This might not be 1-to-1 replacement, so please reach out if there are missing features. " + "See https://hamilton.dagworks.io/en/latest/concepts/caching/ to learn more." +) + + """ Base SERDE functions. diff --git a/hamilton/experimental/h_databackends.py b/hamilton/experimental/h_databackends.py index 288100e8e..3726bc507 100644 --- a/hamilton/experimental/h_databackends.py +++ b/hamilton/experimental/h_databackends.py @@ -130,6 +130,11 @@ class AbstractModinDataFrame(AbstractBackend): _backends = [("modin.pandas", "DataFrame")] +# numpy +class AbstractNumpyArray(AbstractBackend): + _backends = [("numpy", "ndarray")] + + def register_backends() -> Tuple[Tuple[type], Tuple[type]]: """Register databackends defined in this module that include `DataFrame` and `Column` in their class name diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py index 9a7d12fe5..333f103f4 100644 --- a/hamilton/function_modifiers/__init__.py +++ b/hamilton/function_modifiers/__init__.py @@ -73,6 +73,7 @@ tag = metadata.tag tag_outputs = metadata.tag_outputs schema = metadata.schema +cache = metadata.cache # data quality + associated tags check_output = validation.check_output diff --git a/hamilton/function_modifiers/adapters.py b/hamilton/function_modifiers/adapters.py index 82845dc9b..2f975f49f 100644 --- a/hamilton/function_modifiers/adapters.py +++ b/hamilton/function_modifiers/adapters.py @@ -772,9 +772,9 @@ def generate_nodes(self, fn: Callable, config) -> List[node.Node]: { "hamilton.data_loader": True, "hamilton.data_loader.has_metadata": True, - "hamilton.data_loader.source": f"{fn.__name__}", + "hamilton.data_loader.node": f"{fn.__name__}", "hamilton.data_loader.classname": f"{fn.__name__}()", - "hamilton.data_loader.node": _name, + "hamilton.data_loader.source": _name, } ) @@ -790,9 +790,9 @@ def filter_function(**kwargs): tags={ "hamilton.data_loader": True, "hamilton.data_loader.has_metadata": False, - "hamilton.data_loader.source": f"{fn.__name__}", + "hamilton.data_loader.node": f"{fn.__name__}", "hamilton.data_loader.classname": f"{fn.__name__}()", - "hamilton.data_loader.node": fn.__name__, + "hamilton.data_loader.source": fn.__name__, }, ) diff --git a/hamilton/function_modifiers/metadata.py b/hamilton/function_modifiers/metadata.py index cc9e02dd1..c874be6b6 100644 --- a/hamilton/function_modifiers/metadata.py +++ b/hamilton/function_modifiers/metadata.py @@ -1,7 +1,7 @@ """Decorators that attach metadata to nodes""" import json -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union from hamilton import htypes, node, registry from hamilton.function_modifiers import base @@ -315,3 +315,102 @@ def ray_remote_options(**kwargs: Union[int, Dict[str, int]]) -> RayRemote: def example() -> pd.DataFrame: ... 
""" return RayRemote(**kwargs) + + +# materializers that have a `path` kwarg and are part of the core Hamilton library +# parquet, csv, feather, orc, and excel are via the pandas extension because it's currently a Hamilton dependency +CACHE_MATERIALIZERS = Literal[ + "json", + "file", + "pickle", + "parquet", + "csv", + "feather", + "orc", + "excel", +] + +# see hamilton.caching.adapter.CachingBehavior enum for details. +# default: caching is enabled +# recompute: always compute the node instead of retrieving +# ignore: the data version won't be part of downstream keys +# disable: act as if caching wasn't enabled. +CACHE_BEHAVIORS = Literal["default", "recompute", "ignore", "disable"] + + +class cache(base.NodeDecorator): + BEHAVIOR_KEY = "cache.behavior" + FORMAT_KEY = "cache.format" + + def __init__( + self, + *, + behavior: Optional[CACHE_BEHAVIORS] = None, + format: Optional[Union[CACHE_MATERIALIZERS, str]] = None, + target_: base.TargetType = ..., + ): + """The ``@cache`` decorator can define the behavior and format of a specific node. + + This feature is implemented via tags, but that could change. Thus you should not + rely on these tags for other purposes. + + .. code-block:: python + + @cache(behavior="recompute", format="parquet") + def raw_data() -> pd.DataFrame: ... + + + If the function uses other function modifiers and define multiple nodes, you can + set ``target_`` to specify which nodes to cache. The following only caches the ``performance`` node. + + .. code-block:: python + + @cache(format="json", target_="performance") + @extract_fields(trained_model=LinearRegression, performance: dict) + def model_training() -> dict: + # ... + performance = {"rmse": 0.1, "mae": 0.2} + return {"trained_model": trained_model, "performance": performance} + + + :param behavior: The behavior of the cache. This can be one of the following: + * **default**: caching is enabled + * **recompute**: always compute the node instead of retrieving + * **ignore**: the data version won't be part of downstream keys + * **disable**: act as if caching wasn't enabled. + :param format: The format of the cache. This can be one of the following: + * **json**: JSON format + * **file**: file format + * **pickle**: pickle format + * **parquet**: parquet format + * **csv**: csv format + * **feather**: feather format + * **orc**: orc format + * **excel**: excel format + :param target\\_: Target nodes to decorate. This can be one of the following: + * **None**: tag all nodes outputted by this that are "final" (E.g. do not have a node\ + outputted by this that depend on them) + * **Ellipsis (...)**: tag *all* nodes outputted by this + * **Collection[str]**: tag *only* the nodes with the specified names + * **str**: tag *only* the node with the specified name + """ + super(cache, self).__init__(target=target_) + + # don't provide default value for behavior and format if not provided by user + # the SmartCacheAdapter expects the field to be empty if not set + self.cache_tags = {} + if behavior: + self.cache_tags[cache.BEHAVIOR_KEY] = behavior + + if format: + self.cache_tags[cache.FORMAT_KEY] = format + + def decorate_node(self, node_: node.Node) -> node.Node: + """Decorates the nodes with the cache tags. 
+ + :param node_: Node to decorate + :return: Copy of the node, with tags assigned + """ + node_tags = node_.tags.copy() + node_tags.update(self.cache_tags) + return node_.copy_with(tags=node_tags) diff --git a/hamilton/graph.py b/hamilton/graph.py index 8ea97b285..a6246e5d3 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -505,7 +505,7 @@ def _get_legend( # we use tags to identify what is a data loader # but we have two ways that we need to capture, hence the clauses. if n.tags.get("hamilton.data_loader") and ( - "load_data." in n.name or "loader" == n.tags.get("hamilton.data_loader.node") + "load_data." in n.name or "loader" == n.tags.get("hamilton.data_loader.source") ): materializer_type = n.tags["hamilton.data_loader.classname"] label = _get_node_label(n, type_string=materializer_type) diff --git a/hamilton/lifecycle/default.py b/hamilton/lifecycle/default.py index 3966ceca4..313315b24 100644 --- a/hamilton/lifecycle/default.py +++ b/hamilton/lifecycle/default.py @@ -357,6 +357,13 @@ def __init__( self.used_nodes_hash: Dict[str, str] = dict() self.cache.close() + logger.warning( + "The `CacheAdapter` is deprecated and will be removed in Hamilton 2.0. " + "Consider enabling the core caching feature via `Builder.with_cache()`. " + "This might not be 1-to-1 replacement, so please reach out if there are missing features. " + "See https://hamilton.dagworks.io/en/latest/concepts/caching/ to learn more." + ) + def run_before_graph_execution(self, *, graph: HamiltonGraph, **kwargs): """Set `cache_vars` to all nodes if received None during `__init__`""" self.cache = shelve.open(self.cache_path) diff --git a/hamilton/plugins/h_diskcache.py b/hamilton/plugins/h_diskcache.py index a1d710274..f2453b20c 100644 --- a/hamilton/plugins/h_diskcache.py +++ b/hamilton/plugins/h_diskcache.py @@ -87,6 +87,13 @@ def __init__( ) # type: ignore self.used_nodes_hash: Dict[str, str] = dict() + logger.warning( + "The `DiskCacheAdapter` is deprecated and will be removed in Hamilton 2.0. " + "Consider enabling the core caching feature via `Builder.with_cache()`. " + "This might not be 1-to-1 replacement, so please reach out if there are missing features. " + "See https://hamilton.dagworks.io/en/latest/concepts/caching/ to learn more." 
+ ) + def run_before_graph_execution(self, *, graph: graph_types.HamiltonGraph, **kwargs): """Set cache_vars to all nodes if not specified""" if self.cache_vars == []: diff --git a/pyproject.toml b/pyproject.toml index 5660d30a7..609a92ccb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ docs = [ "lz4", "mlflow", "mock==1.0.1", # read the docs pins - "myst-parser==2.0.0", # latest version of myst at this time + "myst-nb", "narwhals", "numpy < 2.0.0", "pandera", diff --git a/tests/caching/__init__.py b/tests/caching/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/caching/test_adapter.py b/tests/caching/test_adapter.py new file mode 100644 index 000000000..08e6bf1da --- /dev/null +++ b/tests/caching/test_adapter.py @@ -0,0 +1,150 @@ +import pytest + +import hamilton.node +from hamilton.caching import fingerprinting +from hamilton.caching.adapter import ( + CachingBehavior, + CachingEventType, + HamiltonCacheAdapter, +) +from hamilton.function_modifiers.metadata import cache as cache_decorator +from hamilton.graph import FunctionGraph + + +@pytest.fixture +def cache_adapter(tmp_path): + def foo() -> str: + return "hello-world" + + run_id = "my-run-id" + adapter = HamiltonCacheAdapter(path=tmp_path) + adapter.metadata_store.initialize(run_id) + adapter._fn_graphs[run_id] = FunctionGraph( + nodes={"foo": hamilton.node.Node.from_fn(foo)}, + config={}, + ) + adapter.behaviors = {run_id: {"foo": CachingBehavior.DEFAULT}} + adapter._data_savers = {run_id: []} + adapter._data_loaders = {run_id: []} + adapter.run_ids.append(run_id) + adapter.data_versions = {run_id: {}} + adapter.code_versions = {run_id: {"foo": "0", "bar": "0"}} + adapter.cache_keys = {run_id: {}} + adapter._logs = {run_id: []} + + yield adapter + + adapter.metadata_store.delete_all() + + +def test_post_node_execute_set_result(cache_adapter): + """Adapter should write to cache and repository if it needs to compute the value""" + node_name = "foo" + result = 123 + data_version = fingerprinting.hash_value(result) + run_id = cache_adapter.last_run_id + + assert cache_adapter.data_versions.get(node_name) is None + assert cache_adapter.result_store.exists(data_version) is False + + cache_adapter.data_versions[run_id][node_name] = data_version + cache_adapter.post_node_execute( + run_id=run_id, # latest run_id + node_=cache_adapter._fn_graphs[run_id].nodes[node_name], + kwargs={}, + result=result, + ) + + assert cache_adapter.result_store.exists(data_version) + assert cache_adapter.result_store.get(data_version=data_version) == result + + +def test_do_execute_reads_data_version_directly_from_memory(cache_adapter): + """Adapter shouldn't check the repository if fingerprint is available""" + + def foo() -> int: + return 123 + + cached_result = foo() + data_version = fingerprinting.hash_value(cached_result) + run_id = cache_adapter.last_run_id + + cache_adapter.data_versions[run_id]["foo"] = data_version + cache_adapter.pre_node_execute(run_id=run_id, node_=hamilton.node.Node.from_fn(foo), kwargs={}) + + assert not any( + event.event_type == CachingEventType.GET_DATA_VERSION and event.actor == "metadata_store" + for event in cache_adapter.logs(run_id) + ) + + +def test_run_to_execute_repo_cache_desync(cache_adapter): + """The adapter determines the value is in cache, + but there's an error loading the value from cache. + + The adapter should delete metadata store keys to force recompute and + writing the result to cache + + NOTE that this will only log and error and not raise any Exception. 
+ This is because adapters cannot currently raise Exception that stop + the main execution. + """ + + def foo() -> int: + return 123 + + run_id = cache_adapter.last_run_id # latest run_id + + # set data version in memory + cache_adapter.data_versions[run_id]["foo"] = fingerprinting.hash_value(foo()) + cache_adapter.do_node_execute(run_id=run_id, node_=hamilton.node.Node.from_fn(foo), kwargs={}) + + # found data version in memory, but the value wasn't in cache + # forcing deletion from metadata_store and recompute + logs = cache_adapter.logs(run_id, level="debug") + assert any(event.event_type == CachingEventType.MISSING_RESULT for event in logs["foo"]) + assert any(event.event_type == CachingEventType.EXECUTE_NODE for event in logs["foo"]) + + +@pytest.mark.parametrize( + "behavior", + [ + "default", + "recompute", + "disable", + "ignore", + ], +) +def test_cache_tag_resolved(cache_adapter, behavior): + node = cache_adapter._fn_graphs[cache_adapter.last_run_id].nodes["foo"] + node._tags = {cache_decorator.BEHAVIOR_KEY: behavior} + resolved_behavior = HamiltonCacheAdapter._resolve_node_behavior(node) + assert resolved_behavior == CachingBehavior.from_string(behavior) + + +def test_default_behavior(cache_adapter): + h_node = cache_adapter._fn_graphs[cache_adapter.last_run_id].nodes["foo"] + resolved_behavior = HamiltonCacheAdapter._resolve_node_behavior(h_node) + assert resolved_behavior == CachingBehavior.DEFAULT + + +def test_driver_behavior_overrides_cache_tag(cache_adapter): + node_name = "foo" + node = cache_adapter._fn_graphs[cache_adapter.last_run_id].nodes[node_name] + node._tags = {cache_decorator.BEHAVIOR_KEY: "recompute"} + + resolved_behavior = HamiltonCacheAdapter._resolve_node_behavior(node=node, disable=[node_name]) + + assert resolved_behavior == CachingBehavior.DISABLE + + +def test_raise_if_multiple_driver_behavior_for_same_node(cache_adapter): + node_name = "foo" + node = cache_adapter._fn_graphs[cache_adapter.last_run_id].nodes[node_name] + + with pytest.raises(ValueError): + HamiltonCacheAdapter._resolve_node_behavior( + node, + disable=[node_name], + recompute=[node_name], + ) diff --git a/tests/caching/test_fingerprinting.py b/tests/caching/test_fingerprinting.py new file mode 100644 index 000000000..b8bd17111 --- /dev/null +++ b/tests/caching/test_fingerprinting.py @@ -0,0 +1,186 @@ +"""Due to the recursive nature of hashing of sequences, mappings, and other +complex types, many tests are not "true" unit tests. The base cases are +the original `hash_value()` and the `hash_primitive()` functions. +""" + +import numpy as np +import pandas as pd +import pytest + +from hamilton.caching import fingerprinting + + +def test_hash_none(): + fingerprint = fingerprinting.hash_value(None) + assert fingerprint == "" + + +def test_hash_no_dict_attribute(): + """Classes without a __dict__ attribute can't be hashed. + during the base case. + """ + + class Foo: + __slots__ = () + + def __init__(self): + pass + + obj = Foo() + fingerprint = fingerprinting.hash_value(obj) + assert not hasattr(obj, "__dict__") + assert fingerprint == fingerprinting.UNHASHABLE + + +def test_hash_recursively(): + """Classes without a specialized hash function are hashed recursively + via their __dict__ attribute. 
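+    For example, ``Foo(obj=None)`` below has ``__dict__ == {"foo": "foo", "obj": None}``,
+    so its fingerprint should match the fingerprint of that plain dict, which is
+    exactly what this test asserts.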
+ """ + + class Foo: + def __init__(self, obj): + self.foo = "foo" + self.obj = obj + + foo0 = Foo(obj=None) + foo1 = Foo(obj=foo0) + foo2 = Foo(obj=foo1) + + foo0_dict = {"foo": "foo", "obj": None} + foo1_dict = {"foo": "foo", "obj": foo0_dict} + foo2_dict = {"foo": "foo", "obj": foo1_dict} + + assert foo0.__dict__ == foo0_dict + # NOTE foo2.__dict__ != foo2_dict, because foo2.__dict__ holds + # a reference to the object foo1, which is not the case for foo2_dict + + fingerprint0 = fingerprinting.hash_value(foo0) + assert fingerprint0 == fingerprinting.hash_value(foo0_dict) + + fingerprint1 = fingerprinting.hash_value(foo1) + assert fingerprint1 == fingerprinting.hash_value(foo1_dict) + + fingerprint2 = fingerprinting.hash_value(foo2) + assert fingerprint2 == fingerprinting.hash_value(foo2_dict) + + +def test_max_recursion_depth(): + """Set the max recursion depth to 0 to prevent any recursion. + After max depth, the default case should return UNHASHABLE. + """ + + class Foo: + def __init__(self, obj): + self.foo = "foo" + self.obj = obj + + foo0 = Foo(obj=None) + foo1 = Foo(obj=foo0) + foo2 = Foo(obj=foo1) + + foo0_dict = {"foo": "foo", "obj": None} + assert foo0.__dict__ == foo0_dict + + fingerprint0 = fingerprinting.hash_value(foo0) + assert fingerprint0 == fingerprinting.hash_value(foo0_dict) + + fingerprinting.set_max_depth(1) + # equivalent after reaching max depth + fingerprint1 = fingerprinting.hash_value(foo1) + fingerprint2 = fingerprinting.hash_value(foo2) + assert fingerprint1 == fingerprint2 + + fingerprinting.set_max_depth(2) + # no longer equivalent after increasing max depth + fingerprint1 = fingerprinting.hash_value(foo1) + fingerprint2 = fingerprinting.hash_value(foo2) + assert fingerprint1 != fingerprint2 + + +@pytest.mark.parametrize( + "obj,expected_hash", + [ + ("hello-world", "IJUxIYl1PeatR9_iDL6X7A=="), + (17.31231, "vAYX8MD8yEHK6dwnIPVUaw=="), + (16474, "L_epMRRUy3Qq5foVvFT_OQ=="), + (True, "-CfPRi9ihI3zfF4elKTadA=="), + (b"\x951!\x89u=\xe6\xadG\xdf", "qK2VJ0vVTRJemfC0beO8iA=="), + ], +) +def test_hash_primitive(obj, expected_hash): + fingerprint = fingerprinting.hash_primitive(obj) + assert fingerprint == expected_hash + + +@pytest.mark.parametrize( + "obj,expected_hash", + [ + ([0, True, "hello-world"], "Pg9LP3Y-8yYsoWLXedPVKDwTAa7W8_fjJNTTUA=="), + ((17.0, False, "world"), "wyuuKMuL8rp53_CdYAtyMmyetnTJ9LzmexhJrQ=="), + ], +) +def test_hash_sequence(obj, expected_hash): + fingerprint = fingerprinting.hash_sequence(obj) + assert fingerprint == expected_hash + + +def test_hash_equals_for_different_sequence_types(): + list_obj = [0, True, "hello-world"] + tuple_obj = (0, True, "hello-world") + expected_hash = "Pg9LP3Y-8yYsoWLXedPVKDwTAa7W8_fjJNTTUA==" + + list_fingerprint = fingerprinting.hash_sequence(list_obj) + tuple_fingerprint = fingerprinting.hash_sequence(tuple_obj) + assert list_fingerprint == tuple_fingerprint == expected_hash + + +def test_hash_ordered_mapping(): + obj = {0: True, "key": "value", 17.0: None} + expected_hash = "1zH9TfTu0-nlWXXXYo0vigFFSQajWXov2w4AZQ==" + fingerprint = fingerprinting.hash_mapping(obj, ignore_order=False) + assert fingerprint == expected_hash + + +def test_hash_mapping_where_order_matters(): + obj1 = {0: True, "key": "value", 17.0: None} + obj2 = {"key": "value", 17.0: None, 0: True} + fingerprint1 = fingerprinting.hash_mapping(obj1, ignore_order=False) + fingerprint2 = fingerprinting.hash_mapping(obj2, ignore_order=False) + assert fingerprint1 != fingerprint2 + + +def test_hash_unordered_mapping(): + obj = {0: True, "key": 
"value", 17.0: None} + expected_hash = "uw0dfSAEgE9nOK3bHgmJ4TR3-VFRqOAoogdRmw==" + fingerprint = fingerprinting.hash_mapping(obj, ignore_order=True) + assert fingerprint == expected_hash + + +def test_hash_mapping_where_order_doesnt_matter(): + obj1 = {0: True, "key": "value", 17.0: None} + obj2 = {"key": "value", 17.0: None, 0: True} + fingerprint1 = fingerprinting.hash_mapping(obj1, ignore_order=True) + fingerprint2 = fingerprinting.hash_mapping(obj2, ignore_order=True) + assert fingerprint1 == fingerprint2 + + +def test_hash_set(): + obj = {0, True, "key", "value", 17.0, None} + expected_hash = "dKyAE-ob4_GD-Mb5Lu2R-VJAxGctY4L8JDwc2g==" + fingerprint = fingerprinting.hash_set(obj) + assert fingerprint == expected_hash + + +def test_hash_pandas(): + """pandas has a specialized hash function""" + obj = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]}) + expected_hash = "LSHACWyG83JBIggxO9LGrerW3WZEy4nUOmIQoA==" + fingerprint = fingerprinting.hash_pandas_obj(obj) + assert fingerprint == expected_hash + + +def test_hash_numpy(): + array = np.array([[0, 1], [2, 3]]) + expected_hash = "ZwjDgY0zQOxO9KPHlYecog==" + fingerprint = fingerprinting.hash_value(array) + assert fingerprint == expected_hash diff --git a/tests/caching/test_integration.py b/tests/caching/test_integration.py new file mode 100644 index 000000000..cb14c33e1 --- /dev/null +++ b/tests/caching/test_integration.py @@ -0,0 +1,619 @@ +from typing import List + +import pandas as pd +import pytest + +from hamilton import ad_hoc_utils, driver +from hamilton.caching.adapter import CachingEventType, HamiltonCacheAdapter +from hamilton.execution.executors import ( + MultiProcessingExecutor, + MultiThreadingExecutor, + SynchronousLocalTaskExecutor, +) +from hamilton.function_modifiers import cache as cache_decorator + +from tests.resources.dynamic_parallelism import parallel_linear_basic, parallelism_with_caching + + +@pytest.fixture +def dr(request, tmp_path): + module = request.param + return driver.Builder().with_modules(module).with_cache(path=tmp_path).build() + + +def execute_dataflow( + module, + cache, + final_vars: list, + config: dict = None, + inputs: dict = None, + overrides: dict = None, +) -> dict: + config = config if config else {} + inputs = inputs if inputs else {} + overrides = overrides if overrides else {} + + dr = driver.Builder().with_modules(module).with_adapters(cache).with_config(config).build() + results = dr.execute(final_vars, inputs=inputs, overrides=overrides) + return results + + +def check_execution(cache, did: List[str] = None, did_not: List[str] = None): + did = did if did is not None else [] + did_not = did_not if did_not is not None else [] + + latest_logs = cache.logs(cache.last_run_id, level="debug") + for did_name in did: + assert any(e.event_type == CachingEventType.EXECUTE_NODE for e in latest_logs[did_name]) + + for did_not_name in did_not: + assert not any( + e.event_type == CachingEventType.EXECUTE_NODE for e in latest_logs[did_not_name] + ) + + +def check_execution_task_based(cache, did: List[str] = None, did_not: List[str] = None): + did = did if did is not None else [] + did_not = did_not if did_not is not None else [] + + latest_logs = cache.logs(cache.last_run_id, level="debug") + for key in latest_logs: + if not isinstance(key, tuple): + # keys that aren't (node_name, task_id) tuples are from the `code_version` event + continue + + node_name, task_id = key + if node_name in did: + assert any(e.event_type == CachingEventType.EXECUTE_NODE for e in latest_logs[key]) + elif node_name in did_not: 
+ assert not any(e.event_type == CachingEventType.EXECUTE_NODE for e in latest_logs[key]) + + +def check_metadata_store_size(cache, size: int): + assert cache.metadata_store.size == size + + +def check_results_exist_in_store(cache, expected_nodes): + run_metadata = cache.metadata_store.get_run(cache.last_run_id) + + for entry in run_metadata: + data_version = entry["data_version"] + if isinstance(data_version, dict): + for item_data_version in data_version.values(): + assert cache.result_store.exists(item_data_version) + else: + assert cache.result_store.exists(data_version) + + +def node_A(): + def A() -> int: + return 1 + + return A + + +def node_A_code_change_same_result(): + def A() -> int: + return 1 + 0 + + return A + + +def node_A_code_change_different_result(): + def A() -> int: + return 2 + + return A + + +def node_B_depends_on_A(): + def B(A: int) -> int: + return 0 - A + + return B + + +def node_B_raises(): + def B(A: int) -> int: + raise ValueError() + + return B + + +def node_C_depends_on_B(): + def C(B: int) -> int: + return B + 1 + + return C + + +def test_code_change_same_result_do_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module_1 = ad_hoc_utils.create_temporary_module(node_A()) + module_2 = ad_hoc_utils.create_temporary_module(node_A_code_change_same_result()) + final_vars = ["A"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"]) + check_metadata_store_size(cache=cache, size=1) + check_results_exist_in_store(cache, ["A"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A"]) + check_metadata_store_size(cache=cache, size=1) + check_results_exist_in_store(cache, ["A"]) + assert results_2 == results_1 + + # execution 3: execute after code change. + execute_dataflow(module=module_2, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A"]) + + +def test_code_change_different_result_do_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module_1 = ad_hoc_utils.create_temporary_module(node_A()) + module_2 = ad_hoc_utils.create_temporary_module(node_A_code_change_different_result()) + final_vars = ["A"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"]) + check_metadata_store_size(cache=cache, size=1) + check_results_exist_in_store(cache, ["A"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A"]) + check_metadata_store_size(cache=cache, size=1) + check_results_exist_in_store(cache, ["A"]) + assert results_2 == results_1 + + # execution 3: execute after code change. 
+ execute_dataflow(module=module_2, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A"]) + + +def test_input_data_change_do_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(node_B_depends_on_A()) + final_vars = ["B"] + inputs1 = {"A": 1} + inputs2 = {"A": 2} + + # execution 1: populate cache + results_1 = execute_dataflow(module=module, cache=cache, final_vars=final_vars, inputs=inputs1) + check_execution(cache=cache, did=["B"]) + check_metadata_store_size(cache=cache, size=1) + check_results_exist_in_store(cache, ["B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module, cache=cache, final_vars=final_vars, inputs=inputs1) + check_execution(cache=cache, did_not=["B"]) + check_metadata_store_size(cache=cache, size=1) + check_results_exist_in_store(cache, ["B"]) + assert results_2 == results_1 + + # execution 3: execute with input data change + execute_dataflow(module=module, cache=cache, final_vars=final_vars, inputs=inputs2) + check_execution(cache=cache, did=["B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["B"]) + + +def test_dependency_code_change_same_result_dont_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module_1 = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + module_2 = ad_hoc_utils.create_temporary_module( + node_A_code_change_same_result(), node_B_depends_on_A() + ) + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + # execution 3: execute with dependency code change + execute_dataflow(module=module_2, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"], did_not=["B"]) + check_metadata_store_size(cache=cache, size=3) + check_results_exist_in_store(cache, ["A", "B"]) + + +def test_dependency_code_change_different_result_do_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module_1 = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + module_2 = ad_hoc_utils.create_temporary_module( + node_A_code_change_different_result(), node_B_depends_on_A() + ) + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + # execution 3: execute with dependency code change + execute_dataflow(module=module_2, cache=cache, 
final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=4) + check_results_exist_in_store(cache, ["A", "B"]) + + +def test_override_with_same_value_dont_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + overrides = {"A": 1} + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + # execution 3: execute with override + execute_dataflow(module=module, cache=cache, final_vars=final_vars, overrides=overrides) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + +def test_override_with_different_value_do_recompute(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + overrides = {"A": 13} + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + # execution 3: execute with override + execute_dataflow(module=module, cache=cache, final_vars=final_vars, overrides=overrides) + check_execution(cache=cache, did=["B"], did_not=["A"]) + check_metadata_store_size(cache=cache, size=3) + check_results_exist_in_store(cache, ["B"]) # overrides are not stored + + +def test_node_that_raises_error(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module_1 = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + module_2 = ad_hoc_utils.create_temporary_module(node_A(), node_B_raises()) + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module_1, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + # execution 3: execute with raising node + # B doesn't count as `did execute` because it raised an Exception + with pytest.raises(ValueError): + execute_dataflow(module=module_2, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A"]) + assert any( + e.event_type == 
CachingEventType.FAILED_EXECUTION + for e in cache.logs(cache.run_ids[-1], level="debug")["B"] + ) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A"]) + + +def test_caching_pandas_dataframe(tmp_path): + def A() -> pd.DataFrame: + return pd.DataFrame({"foo": [0, 1], "bar": ["a", "b"]}) + + def B(A: pd.DataFrame) -> pd.DataFrame: + A["baz"] = pd.Series([True, False]) + return A + + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(A, B) + final_vars = ["B"] + + # execution 1: populate cache + execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + +def test_recompute_behavior(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + cache._recompute = ["A"] + # execution 3: force recompute A + # metadata size doesn't increase because it's a duplicate entry + execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"], did_not=["B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + +def test_disable_behavior(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module( + node_A(), node_B_depends_on_A(), node_C_depends_on_B() + ) + final_vars = ["C"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B", "C"]) + check_metadata_store_size(cache=cache, size=3) + check_results_exist_in_store(cache, ["A", "B", "C"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B", "C"]) + check_metadata_store_size(cache=cache, size=3) + check_results_exist_in_store(cache, ["A", "B", "C"]) + assert results_2 == results_1 + + cache._disable = ["A"] + # execution 3: disable A means it forces reexecution of dependent nodes + # A doesn't produce any metadata or result + # metadata size grows for each rexecution of B + execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"], did_not=["C"]) + check_metadata_store_size(cache=cache, size=4) + check_results_exist_in_store(cache, ["A", "B", "C"]) + + # execution 4: keeps forcing re-execution of A and B as long as A is DISABLE + 
execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"], did_not=["C"]) + check_metadata_store_size(cache=cache, size=5) + check_results_exist_in_store(cache, ["A", "B", "C"]) + + +def test_ignore_behavior(tmp_path): + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(node_A(), node_B_depends_on_A()) + final_vars = ["B"] + + # execution 1: populate cache + results_1 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 2: retrieve under the same condition + results_2 = execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did_not=["A", "B"]) + check_metadata_store_size(cache=cache, size=2) + check_results_exist_in_store(cache, ["A", "B"]) + assert results_2 == results_1 + + cache._ignore = ["A"] + # execution 3: a new key that ignores A will be recomputed for B + # A doesn't produce any metadata or result + execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A", "B"]) + check_metadata_store_size(cache=cache, size=4) + check_results_exist_in_store(cache, ["A", "B"]) + + # execution 4: B can be retrieved using its new key that ignores A + execute_dataflow(module=module, cache=cache, final_vars=final_vars) + check_execution(cache=cache, did=["A"], did_not=["B"]) + check_metadata_store_size(cache=cache, size=5) + check_results_exist_in_store(cache, ["A", "B"]) + + +def test_result_is_materialized_to_file(tmp_path): + @cache_decorator(format="json") + def foo() -> dict: + return {"hello": "world"} + + node_name = "foo" + module = ad_hoc_utils.create_temporary_module(foo) + dr = driver.Builder().with_modules(module).with_cache(path=tmp_path).build() + + result = dr.execute([node_name]) + data_version = dr.cache.version_data(result[node_name]) + retrieved_result = dr.cache.result_store.get(data_version) + + assert result[node_name] == retrieved_result + + +@pytest.mark.parametrize( + "executor", + [ + SynchronousLocalTaskExecutor(), + MultiProcessingExecutor(max_tasks=10), + MultiThreadingExecutor(max_tasks=10), + ], +) +def test_parallel_synchronous_step_by_step(tmp_path, executor): + dr = ( + driver.Builder() + .with_modules(parallel_linear_basic) + .with_cache(path=tmp_path) + .enable_dynamic_execution(allow_experimental_mode=True) + .with_remote_executor(executor) + .build() + ) + + dr.execute(["final"]) + check_execution_task_based( + cache=dr.cache, + did=[ + "number_of_steps", + "steps", + "step_squared", + "step_cubed", + "step_squared_plus_step_cubed", + "sum_step_squared_plus_step_cubed", + "final", + ], + ) + check_metadata_store_size(cache=dr.cache, size=22) + check_results_exist_in_store( + cache=dr.cache, + expected_nodes=[ + "number_of_steps", + "steps", + "step_squared", + "step_cubed", + "step_squared_plus_step_cubed", + "sum_step_squared_plus_step_cubed", + "final", + ], + ) + + # execution 2: expand node `steps` must be recomputed because of the iterator. 
+ dr.execute(["final"]) + check_execution_task_based( + cache=dr.cache, + did=["steps"], + did_not=[ + "number_of_steps", + "step_squared", + "step_cubed", + "step_squared_plus_step_cubed", + "sum_step_squared_plus_step_cubed", + "final", + ], + ) + check_metadata_store_size(cache=dr.cache, size=22) + check_results_exist_in_store( + cache=dr.cache, + expected_nodes=[ + "number_of_steps", + "steps", + "step_squared", + "step_cubed", + "step_squared_plus_step_cubed", + "sum_step_squared_plus_step_cubed", + "final", + ], + ) + + +@pytest.mark.parametrize( + "executor", + [ + SynchronousLocalTaskExecutor(), + MultiProcessingExecutor(max_tasks=10), + MultiThreadingExecutor(max_tasks=10), + ], +) +def test_materialize_parallel_branches(tmp_path, executor): + # NOTE the module can't be defined here because multithreading requires functions to be top-level. + dr = ( + driver.Builder() + .with_modules(parallelism_with_caching) + .with_cache(path=tmp_path) + .enable_dynamic_execution(allow_experimental_mode=True) + .with_remote_executor(executor) + .build() + ) + + # execution 1 + dr.execute(["collect_node"]) + check_execution_task_based(cache=dr.cache, did=["expand_node", "inside_branch", "collect_node"]) + check_metadata_store_size(cache=dr.cache, size=10) + check_results_exist_in_store( + cache=dr.cache, expected_nodes=["expand_node", "inside_branch", "collect_node"] + ) + + # execution 2: expand node must be recomputed because of the iterator. + # values for `inside_branch` are retrieved from the JSON materialization + dr.execute(["collect_node"]) + check_execution_task_based( + cache=dr.cache, did=["expand_node"], did_not=["inside_branch", "collect_node"] + ) + check_metadata_store_size(cache=dr.cache, size=10) + check_results_exist_in_store( + cache=dr.cache, expected_nodes=["expand_node", "inside_branch", "collect_node"] + ) + + +def test_consistent_cache_key_with_or_without_defaut_parameter(tmp_path): + def foo(external_dep: int = 3) -> int: + return external_dep + 1 + + cache = HamiltonCacheAdapter(path=tmp_path) + module = ad_hoc_utils.create_temporary_module(foo) + final_vars = ["foo"] + inputs_1 = {} + inputs_2 = {"external_dep": 3} + + # execution 1: populate cache + execute_dataflow(module=module, cache=cache, final_vars=final_vars, inputs=inputs_1) + check_execution(cache=cache, did=["foo"]) + cache_key_1 = cache.cache_keys[cache.last_run_id]["foo"] + + # execution 2: retrieve under the same condition + execute_dataflow(module=module, cache=cache, final_vars=final_vars, inputs=inputs_2) + check_execution(cache=cache, did_not=["foo"]) + cache_key_2 = cache.cache_keys[cache.last_run_id]["foo"] + + assert cache_key_1 == cache_key_2 diff --git a/tests/caching/test_metadata_store.py b/tests/caching/test_metadata_store.py new file mode 100644 index 000000000..ab3d61751 --- /dev/null +++ b/tests/caching/test_metadata_store.py @@ -0,0 +1,100 @@ +import pytest + +from hamilton.caching.cache_key import create_cache_key +from hamilton.caching.stores.sqlite import SQLiteMetadataStore + + +@pytest.fixture +def metadata_store(request, tmp_path): + metdata_store_cls = request.param + metadata_store = metdata_store_cls(path=tmp_path) + run_id = "test-run-id" + try: + metadata_store.initialize(run_id) + except BaseException: + pass + + yield metadata_store + + metadata_store.delete_all() + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_initialize_empty(metadata_store): + assert metadata_store.size == 0 + + +@pytest.mark.parametrize("metadata_store", 
[SQLiteMetadataStore], indirect=True) +def test_not_empty_after_set(metadata_store): + code_version = "FOO-1" + data_version = "foo-a" + node_name = "foo" + cache_key = create_cache_key( + node_name=node_name, code_version=code_version, dependencies_data_versions={} + ) + + metadata_store.set( + cache_key=cache_key, + node_name=node_name, + code_version=code_version, + data_version=data_version, + run_id="...", + ) + + assert metadata_store.size > 0 + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_set_doesnt_produce_duplicates(metadata_store): + code_version = "FOO-1" + data_version = "foo-a" + node_name = "foo" + cache_key = create_cache_key( + node_name=node_name, code_version=code_version, dependencies_data_versions={} + ) + metadata_store.set( + cache_key=cache_key, + node_name=node_name, + code_version=code_version, + data_version=data_version, + run_id="...", + ) + assert metadata_store.size == 1 + + metadata_store.set( + cache_key=cache_key, + node_name=node_name, + code_version=code_version, + data_version=data_version, + run_id="...", + ) + assert metadata_store.size == 1 + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_get_miss_returns_none(metadata_store): + cache_key = create_cache_key( + node_name="foo", code_version="FOO-1", dependencies_data_versions={"bar": "bar-a"} + ) + data_version = metadata_store.get(cache_key=cache_key) + assert data_version is None + + +@pytest.mark.parametrize("metadata_store", [SQLiteMetadataStore], indirect=True) +def test_set_get_without_dependencies(metadata_store): + code_version = "FOO-1" + data_version = "foo-a" + node_name = "foo" + cache_key = create_cache_key( + node_name=node_name, code_version=code_version, dependencies_data_versions={} + ) + metadata_store.set( + cache_key=cache_key, + node_name=node_name, + code_version=code_version, + data_version=data_version, + run_id="...", + ) + retrieved_data_version = metadata_store.get(cache_key=cache_key) + + assert retrieved_data_version == data_version diff --git a/tests/caching/test_result_store.py b/tests/caching/test_result_store.py new file mode 100644 index 000000000..9afbd455a --- /dev/null +++ b/tests/caching/test_result_store.py @@ -0,0 +1,116 @@ +import pathlib +import pickle + +import pytest + +from hamilton.caching import fingerprinting +from hamilton.caching.stores.base import search_data_adapter_registry +from hamilton.caching.stores.file import FileResultStore + + +@pytest.fixture +def result_store(tmp_path): + store = FileResultStore(tmp_path / "h-cache") + + yield store + + store.delete_all() + + +def check_result_store_size(result_store, size: int): + assert len([p for p in result_store.path.iterdir()]) == size + + +def test_set(result_store): + data_version = "foo" + assert not pathlib.Path(result_store.path, data_version).exists() + + result_store.set(data_version=data_version, result="bar") + + assert pathlib.Path(result_store.path, data_version).exists() + check_result_store_size(result_store, size=1) + + +def test_exists(result_store): + data_version = "foo" + assert ( + result_store.exists(data_version) == pathlib.Path(result_store.path, data_version).exists() + ) + + result_store.set(data_version=data_version, result="bar") + + assert ( + result_store.exists(data_version) == pathlib.Path(result_store.path, data_version).exists() + ) + + +def test_set_doesnt_produce_duplicates(result_store): + data_version = "foo" + assert not result_store.exists(data_version) + + 
+    result_store.set(data_version=data_version, result="bar")
+    result_store.set(data_version=data_version, result="bar")
+
+    assert result_store.exists(data_version)
+    check_result_store_size(result_store, size=1)
+
+
+def test_get(result_store):
+    data_version = "foo"
+    result = "bar"
+    pathlib.Path(result_store.path, data_version).write_bytes(pickle.dumps(result))
+    assert result_store.exists(data_version)
+
+    retrieved_value = result_store.get(data_version)
+
+    assert retrieved_value
+    assert result == retrieved_value
+    check_result_store_size(result_store, size=1)
+
+
+def test_get_missing_result_is_none(result_store):
+    result = result_store.get("foo")
+    assert result is None
+
+
+def test_delete(result_store):
+    data_version = "foo"
+    result_store.set(data_version, "bar")
+    assert pathlib.Path(result_store.path, data_version).exists()
+    check_result_store_size(result_store, size=1)
+
+    result_store.delete(data_version)
+
+    assert not pathlib.Path(result_store.path, data_version).exists()
+    check_result_store_size(result_store, size=0)
+
+
+def test_delete_all(result_store):
+    result_store.set("foo", "foo")
+    result_store.set("bar", "bar")
+    check_result_store_size(result_store, size=2)
+
+    result_store.delete_all()
+
+    check_result_store_size(result_store, size=0)
+
+
+@pytest.mark.parametrize(
+    "format,value",
+    [
+        ("json", {"key1": "value1", "key2": 2}),
+        ("pickle", ("value1", "value2", "value3")),
+    ],
+)
+def test_save_and_load_materializer(format, value, result_store):
+    saver_cls, loader_cls = search_data_adapter_registry(name=format, type_=type(value))
+    data_version = "foo"
+    materialized_path = result_store._materialized_path(data_version, saver_cls)
+
+    result_store.set(
+        data_version=data_version, result=value, saver_cls=saver_cls, loader_cls=loader_cls
+    )
+    retrieved_value = result_store.get(data_version)
+
+    assert materialized_path.exists()
+    assert fingerprinting.hash_value(value) == fingerprinting.hash_value(retrieved_value)
diff --git a/tests/function_modifiers/test_adapters.py b/tests/function_modifiers/test_adapters.py
index c79f82370..08f2bb067 100644
--- a/tests/function_modifiers/test_adapters.py
+++ b/tests/function_modifiers/test_adapters.py
@@ -777,8 +777,8 @@ def test_dataloader():
         "hamilton.data_loader": True,
         "hamilton.data_loader.classname": "correct_dl_function()",
         "hamilton.data_loader.has_metadata": True,
-        "hamilton.data_loader.node": "loader",
-        "hamilton.data_loader.source": "correct_dl_function",
+        "hamilton.data_loader.node": "correct_dl_function",
+        "hamilton.data_loader.source": "loader",
         "module": "tests.function_modifiers.test_adapters",
     }
     assert node2.name == "correct_dl_function"
diff --git a/tests/resources/dynamic_parallelism/parallelism_with_caching.py b/tests/resources/dynamic_parallelism/parallelism_with_caching.py
new file mode 100644
index 000000000..0f24cd137
--- /dev/null
+++ b/tests/resources/dynamic_parallelism/parallelism_with_caching.py
@@ -0,0 +1,16 @@
+from hamilton import htypes
+from hamilton.function_modifiers import cache
+
+
+def expand_node() -> htypes.Parallelizable[int]:
+    for i in (0, 1, 2, 3, 4, 5, 6, 7):
+        yield i
+
+
+@cache(format="json")
+def inside_branch(expand_node: int) -> dict:
+    return {"value": expand_node}
+
+
+def collect_node(inside_branch: htypes.Collect[dict]) -> list:
+    return list(inside_branch)
diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py
index 23fe8ac71..0be0543e8 100644
--- a/tests/test_hamilton_driver.py
+++ b/tests/test_hamilton_driver.py
@@ -4,6 +4,7 @@
 import pytest
 
 from hamilton import base, node
+from hamilton.caching.adapter import HamiltonCacheAdapter
 from hamilton.driver import (
     Builder,
     Driver,
@@ -623,6 +624,23 @@ def test_materialize_checks_required_input(tmp_path):
     )
 
 
+def test_cache_raise_if_setting_twice(tmp_path):
+    builder = Builder()
+
+    builder.with_cache(path=tmp_path)
+    # case 1: .with_cache() then .with_cache()
+    with pytest.raises(ValueError):
+        builder.with_cache(path=tmp_path)
+    # case 2: .with_cache() then adding HamiltonCacheAdapter()
+    with pytest.raises(ValueError):
+        builder.with_adapters(HamiltonCacheAdapter(path=tmp_path))
+    # case 3: add HamiltonCacheAdapter() then .with_cache()
+    builder = Builder()
+    builder.with_adapters(HamiltonCacheAdapter(path=tmp_path))
+    with pytest.raises(ValueError):
+        builder.with_cache()
+
+
 def test_validate_execution_happy():
     dr = Builder().with_modules(tests.resources.very_simple_dag).build()
     dr.validate_execution(["b"], inputs={"a": 1})