Skip to content

Commit

Permalink
Fixes cache issue from 703 and 679 (#707)
Browse files Browse the repository at this point in the history
* Avoid creating cache database when cache_seed is None (which disables cache)

* Add some debugging.

* Removed some debugging.

* Update autogen/oai/client.py

Co-authored-by: Chi Wang <[email protected]>

* Fixed missing filter function logic from oai/client.py

---------

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
afourney and sonichi authored Nov 20, 2023
1 parent 143e49c commit ef1c3d3
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,10 @@ def yes_or_no_filter(context, response):
cache_seed = extra_kwargs.get("cache_seed", 41)
filter_func = extra_kwargs.get("filter_func")
context = extra_kwargs.get("context")
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
if cache_seed is not None:

# Try to load the response from cache
if cache_seed is not None:
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
# Try to get the response from cache
key = get_key(params)
response = cache.get(key, None)
Expand All @@ -240,17 +242,28 @@ def yes_or_no_filter(context, response):
response.pass_filter = pass_filter
# TODO: add response.cost
return response
try:
response = self._completions_create(client, params)
except APIError:
logger.debug(f"config {i} failed", exc_info=1)
if i == last:
raise
else:
if cache_seed is not None:
# Cache the response
continue # filter is not passed; try the next config
try:
response = self._completions_create(client, params)
except APIError:
logger.debug(f"config {i} failed", exc_info=1)
if i == last:
raise
else:
if cache_seed is not None:
# Cache the response
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
cache.set(key, response)

# check the filter
pass_filter = filter_func is None or filter_func(context=context, response=response)
if pass_filter or i == last:
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
# TODO: add response.cost
return response
continue # filter is not passed; try the next config

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
Expand Down

0 comments on commit ef1c3d3

Please sign in to comment.