diff --git a/langchain/utilities/__init__.py b/langchain/utilities/__init__.py index e8e88f3b723f2..4ba7c28d3cc46 100644 --- a/langchain/utilities/__init__.py +++ b/langchain/utilities/__init__.py @@ -3,7 +3,7 @@ from langchain.requests import RequestsWrapper from langchain.utilities.bash import BashProcess from langchain.utilities.bing_search import BingSearchAPIWrapper -from langchain.utilities.imun import ImunAPIWrapper, ImunMultiAPIWrapper +from langchain.utilities.imun import ImunAPIWrapper, ImunMultiAPIWrapper, ImunCache from langchain.utilities.google_search import GoogleSearchAPIWrapper from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.searx_search import SearxSearchWrapper @@ -22,4 +22,5 @@ "BingSearchAPIWrapper", "ImunAPIWrapper", "ImunMultiAPIWrapper", + "ImunCache", ] diff --git a/langchain/utilities/imun.py b/langchain/utilities/imun.py index cbb866342f837..004d0a67374ea 100644 --- a/langchain/utilities/imun.py +++ b/langchain/utilities/imun.py @@ -4,7 +4,7 @@ https://azure.microsoft.com/en-us/products/cognitive-services/computer-vision """ import time -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional import io import imagesize @@ -363,6 +363,16 @@ def create_prompt(results: Dict) -> str: answer += IMUN_PROMPT_CELEBS.format(celebs=_concat_objects(celebrities, size=size)) return answer +class ImunCache(BaseModel): + cache: Optional[dict] = {} #: :meta private: + class Config: + copy_on_model_validation = 'none' + + def get(self, key:str)->dict: + return self.cache.get(key) + + def set(self, key:str, value:dict): + self.cache[key] = value class ImunAPIWrapper(BaseModel): """Wrapper for Image Understanding API. @@ -371,7 +381,7 @@ class ImunAPIWrapper(BaseModel): https://azure.microsoft.com/en-us/products/cognitive-services/computer-vision """ - cache: dict #: :meta private: + cache: Optional[ImunCache] #: :meta private: imun_subscription_key: str imun_url: str params: dict # "api-version=2023-02-01-preview&features=denseCaptions,Tags" @@ -383,9 +393,14 @@ class Config: def _imun_results(self, img_url: str) -> dict: param_str = '&'.join([f'{k}={v}' for k,v in self.params.items()]) - key = f"{self.imun_url}?{param_str}&data={img_url}" - if key in self.cache: - return self.cache[key] + key = f"{self.imun_url}?{param_str}" + img_cache = self.cache.get(img_url) + if img_cache: + if key in img_cache: + return img_cache[key] + else: + img_cache = {} + self.cache.set(img_url, img_cache) results = {"task": []} if "celebrities" in self.imun_url: results["task"].append("celebrities") @@ -489,7 +504,7 @@ def _imun_results(self, img_url: str) -> dict: results["languages"] = languages if _is_handwritten(analyzeResult["styles"]): results["words_style"] = "handwritten " - self.cache[key] = results + img_cache[key] = results return results @root_validator(pre=True)