Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
swyxio committed Apr 13, 2023
1 parent ff977f4 commit 66ffbc4
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions scripts/memory/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@
import datetime
import chromadb
from memory.base import MemoryProviderSingleton, get_embedding
from chromadb.errors import NoIndexException

class LocalCache(MemoryProviderSingleton):

# on load, load our database
def __init__(self, cfg) -> None:
self.persistence = cfg.memory_directory
if os.path.exists(self.persistence):
if self.persistence is not None and os.path.exists(self.persistence):
self.chromaClient = chromadb.Client(Settings(
chroma_db_impl="duckdb+parquet", # duckdb+parquet = persisted, duckdb = in-memory
persist_directory=self.persistence
))
else:
# in memory
print(f"Warning: The directory '{self.persistence}' does not exist. Chroma memory would not be saved to a file.")
self.chromaClient = chromadb.Client()
self.chromaCollection = self.chromaClient.create_collection(name="autoGPT_collection")
self.chromaCollection = self.chromaClient.create_collection(name="autogpt")
# we will key off of cfg.openai_embeddings_model to determine if using sentence transformers or openai embeddings
self.useOpenAIEmbeddings = True if (cfg.openai_embeddings_model) else False

Expand Down Expand Up @@ -50,7 +50,7 @@ def clear(self) -> str:

chroma_client = self.chromaClient
chroma_client.reset()
self.chromaCollection = chroma_client.create_collection(name="autoGPT_collection")
self.chromaCollection = chroma_client.create_collection(name="autogpt")
return "Obliviated"

def get(self, data: str) -> Optional[List[Any]]:
Expand All @@ -65,30 +65,34 @@ def get(self, data: str) -> Optional[List[Any]]:
results = None
if self.useOpenAIEmbeddings:
embeddings = get_embedding(data)
results = self.collection.query(
results = self.chromaCollection.query(
query_embeddings=[data],
n_results=1
)
else:
results = self.collection.query(
results = self.chromaCollection.query(
query_texts=[data],
n_results=1
)
return results

def get_relevant(self, text: str, k: int) -> List[Any]:
results = None
if self.useOpenAIEmbeddings:
embeddings = get_embedding(data)
results = self.collection.query(
query_embeddings=[data],
n_results=k
)
else:
results = self.collection.query(
query_texts=[data],
n_results=k
)
try:
if self.useOpenAIEmbeddings:
embeddings = get_embedding(text)
results = self.chromaCollection.query(
query_embeddings=[text],
n_results=min(k, self.chromaCollection.count())
)
else:
results = self.chromaCollection.query(
query_texts=[text],
n_results=min(k, self.chromaCollection.count())
)
except NoIndexException:
# print("No index found - suppressed because this is a common issue for first-run users")
pass
return results

def get_stats(self):
Expand Down

0 comments on commit 66ffbc4

Please sign in to comment.