Cache completions #389

Open
abrichr opened this issue Nov 22, 2024 · 2 comments

abrichr commented Nov 22, 2024

Is there some mechanism to avoid hitting the API if the prompt hasn't changed at all?

For example:

import ell

@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant.""" # System prompt
    return f"Say hello to {name}!" # User prompt

greeting = hello("Sam Altman")
print(greeting)

If we run this script twice, there is no need to call the API the second time if we simply persist the result of the function call to disk.

Normally we can accomplish this with joblib.Memory:

from joblib import Memory
import ell

memory = Memory("./cache")

@memory.cache()
@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant.""" # System prompt
    return f"Say hello to {name}!" # User prompt

greeting = hello("Sam Altman")
print(greeting)

Now if we run this script twice, the API is not hit on the second run.

This behaves as expected when the arguments change: calling hello("Sam") hits the API, since the cache key includes the arguments.

However, if we change the prompt literal inside the function, joblib does not pick up on the change, and the stale cached result is returned.
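One way to observe the stale hit directly is to ask joblib whether a given call would be served from the cache. A minimal sketch, assuming joblib's check_call_in_cache method on functions decorated with memory.cache:

from joblib import Memory
import ell

memory = Memory("./cache")

@memory.cache()
@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant."""
    return f"Say hello to {name}!"

# True if this exact argument tuple is already cached. It stays True
# even after the prompt literal above is edited, which is exactly the
# staleness problem described here.
print(hello.check_call_in_cache("Sam Altman"))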

Any suggestions for avoiding unnecessary API calls would be appreciated!

abrichr (Author) commented Nov 22, 2024

Workaround: incorporate a hash of the function’s source code (including the prompt) into the cache key.

import functools
import hashlib
import inspect

from joblib import Memory
import ell

memory = Memory("./cache")

def hash_source_code(func):
    """Hash the entire function's source code, including its docstring."""
    source = inspect.getsource(func)
    return hashlib.sha256(source.encode("utf-8")).hexdigest()

def cache_source(func):
    """Decorator to cache a function, keying the cache on its source code."""
    # Hash once at decoration time; the source cannot change within a run.
    func_hash = hash_source_code(func)

    # joblib derives the cache key from cached_call's arguments, so passing
    # func_hash along invalidates the cache whenever the source changes.
    @memory.cache()
    def cached_call(func_hash, func_args, func_kwargs):
        return func(*func_args, **func_kwargs)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return cached_call(func_hash, args, kwargs)

    return wrapper

# Example usage
@cache_source
@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant."""
    return f"Say hello to {name}!"

greeting = hello("Sam Altman")
print(greeting)

Now, modifying the prompt invalidates the cache instead of re-using the stale API result.
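If the cache ever needs to be reset wholesale, joblib's Memory.clear wipes everything under the cache directory. A short sketch:

from joblib import Memory

memory = Memory("./cache")

# Drop every result stored under ./cache, forcing fresh API calls
# on the next run.
memory.clear(warn=False)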

gwpl (Contributor) commented Dec 3, 2024

cross-linking: #200
