
Use Chroma to augment and make LocalCache memory/embeddings pluggable #1027

Closed
Changes from all commits
46 commits
328c5e0
add initial attempt
swyxio Apr 13, 2023
701de56
quick fix
swyxio Apr 13, 2023
ff977f4
add config
swyxio Apr 13, 2023
66ffbc4
fix bugs
swyxio Apr 13, 2023
4c8aaa5
Merge branch 'master' of https://github.com/Torantulino/Auto-GPT into…
swyxio Apr 13, 2023
9cd4185
Merge branch 'master' of https://github.com/Torantulino/Auto-GPT into…
swyxio Apr 14, 2023
a6928c2
couple other things broke
swyxio Apr 14, 2023
3b97834
Merge branch 'master' of https://github.com/Torantulino/Auto-GPT into…
swyxio Apr 15, 2023
2e6c5d8
Merge branch 'master' of https://github.com/Torantulino/Auto-GPT into…
swyxio Apr 16, 2023
581aafe
solve merge conflicts
swyxio Apr 16, 2023
2d82d24
fix merge
swyxio Apr 16, 2023
c4bb6e1
fix merge
swyxio Apr 16, 2023
53180a6
fix merge
swyxio Apr 16, 2023
36ade6a
Merge branch 'master' into swyx/extendLocalCache
swyxio Apr 16, 2023
751f7a8
Merge branch 'master' of https://github.com/Torantulino/Auto-GPT into…
swyxio Apr 16, 2023
fbd6735
remove MEMORY_DIRECTORY config
swyxio Apr 16, 2023
6197c04
Merge branch 'swyx/extendLocalCache'
swyxio Apr 16, 2023
e4f483d
Merge branch 'Significant-Gravitas:master' into swyx/extendLocalCache
swyxio Apr 16, 2023
52290bd
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 16, 2023
81dc423
fix flake tests
swyxio Apr 16, 2023
5b8bb22
run test
swyxio Apr 16, 2023
0490c67
add to gitignore
swyxio Apr 16, 2023
d0ef2be
fix flake
swyxio Apr 16, 2023
d735ae3
Merge branch 'Significant-Gravitas:master' into swyx/extendLocalCache
swyxio Apr 16, 2023
0a482bf
undo workspace change
swyxio Apr 16, 2023
914305e
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 16, 2023
01b97cb
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 16, 2023
68465f9
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 16, 2023
d6817a6
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 17, 2023
bd1f752
Merge branch 'Significant-Gravitas:master' into swyx/extendLocalCache
swyxio Apr 17, 2023
8dee616
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 17, 2023
9dff547
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 18, 2023
e7d5e85
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 18, 2023
bb8f589
Merge branch 'Significant-Gravitas:master' into swyx/extendLocalCache
swyxio Apr 18, 2023
0e4df58
Merge branch 'Significant-Gravitas:master' into swyx/extendLocalCache
swyxio Apr 19, 2023
01fe911
fix minor bugs with chroma retrieval format for autogpt
swyxio Apr 19, 2023
7e5f5e6
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 19, 2023
9b3ffef
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 20, 2023
52c59f1
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 21, 2023
10fca1d
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 21, 2023
f0824b6
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 21, 2023
65a40ed
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 24, 2023
bf2b9aa
update broken apis
swyxio Apr 24, 2023
01dc63e
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 24, 2023
009abd7
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 25, 2023
2b18249
Merge branch 'swyx/extendLocalCache' of https://github.com/sw-yx/Auto…
swyxio Apr 26, 2023
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ log-ingestion.txt
logs
*.log
*.mp3
localCache/*
mem.sqlite3

# Byte-compiled / optimized / DLL files
6 changes: 2 additions & 4 deletions README.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions autogpt/config/config.py
@@ -39,6 +39,7 @@ def __init__(self) -> None:
)

self.openai_api_key = os.getenv("OPENAI_API_KEY")
self.openai_embeddings_model = os.getenv("OPENAI_EMBEDDINGS_MODEL")
self.temperature = float(os.getenv("TEMPERATURE", "0"))
self.use_azure = os.getenv("USE_AZURE") == "True"
self.execute_local_commands = (
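The new `openai_embeddings_model` setting is what `LocalCache` later keys off to decide between OpenAI embeddings and Chroma's built-in sentence-transformers embedder. A minimal sketch of that toggle (the helper name is illustrative, not part of the PR):

```python
import os

def use_openai_embeddings() -> bool:
    # Mirrors the Config change: the model name is read straight from the
    # environment, and any non-empty value switches LocalCache from Chroma's
    # default sentence-transformers embedder to OpenAI embeddings.
    return bool(os.getenv("OPENAI_EMBEDDINGS_MODEL"))

os.environ["OPENAI_EMBEDDINGS_MODEL"] = "text-embedding-ada-002"
assert use_openai_embeddings()

del os.environ["OPENAI_EMBEDDINGS_MODEL"]
assert not use_openai_embeddings()
```

Leaving the variable unset keeps the previous, fully local behaviour.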
4 changes: 1 addition & 3 deletions autogpt/llm_utils.py
@@ -208,8 +208,7 @@ def create_chat_completion(
resp = plugin.on_response(resp)
return resp


def get_ada_embedding(text: str) -> List[float]:
def get_embedding(text: str, model="text-embedding-ada-002") -> List[float]:
"""Get an embedding for the given text from the specified model.

Args:
@@ -219,7 +218,6 @@ def get_ada_embedding(text: str) -> List[float]:
List[float]: The embedding.
"""
cfg = Config()
model = "text-embedding-ada-002"
text = text.replace("\n", " ")

if cfg.use_azure:
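The rename from `get_ada_embedding` to `get_embedding` moves the model name out of a hard-coded local variable and into a keyword argument. An offline stand-in (placeholder return value, not the real OpenAI call) showing the resulting calling pattern:

```python
from typing import List

def get_embedding_sketch(text: str, model: str = "text-embedding-ada-002") -> List[str]:
    # Stand-in for the refactored signature: existing call sites keep the old
    # ada default, while new callers can pass any embedding model name.
    text = text.replace("\n", " ")  # same newline normalisation as the real function
    return [model, text]  # placeholder; the real function returns List[float]

assert get_embedding_sketch("a\nb") == ["text-embedding-ada-002", "a b"]
assert get_embedding_sketch("x", model="text-embedding-3-small")[0] == "text-embedding-3-small"
```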
151 changes: 78 additions & 73 deletions autogpt/memory/local.py
@@ -1,29 +1,14 @@
from __future__ import annotations

import dataclasses
from pathlib import Path
from typing import Any, List

import numpy as np
import orjson

from autogpt.llm_utils import get_ada_embedding
from typing import Any, List, Optional, Tuple
import os
import uuid
import datetime
import chromadb
import logging
from autogpt.llm_utils import create_embedding
from autogpt.memory.base import MemoryProviderSingleton

EMBED_DIM = 1536
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS


def create_default_embeddings():
return np.zeros((0, EMBED_DIM)).astype(np.float32)


@dataclasses.dataclass
class CacheContent:
texts: List[str] = dataclasses.field(default_factory=list)
embeddings: np.ndarray = dataclasses.field(
default_factory=create_default_embeddings
)
from chromadb.errors import NoIndexException
from chromadb.config import Settings


class LocalCache(MemoryProviderSingleton):
@@ -38,16 +23,22 @@ def __init__(self, cfg) -> None:
Returns:
None
"""
workspace_path = Path(cfg.workspace_path)
self.filename = workspace_path / f"{cfg.memory_index}.json"

self.filename.touch(exist_ok=True)
# switch chroma's logging defaults to error only
logging.getLogger('chromadb').setLevel(logging.ERROR)

file_content = b"{}"
with self.filename.open("w+b") as f:
f.write(file_content)
# disable huggingface/tokenizers fork warning (os is already imported at module level)
os.environ['TOKENIZERS_PARALLELISM'] = "True"

self.data = CacheContent()

self.chromaClient = chromadb.Client(Settings(
chroma_db_impl="duckdb+parquet", # persists to disk; comment out to run purely in-memory
persist_directory="localCache"
))
self.chromaCollection = self.chromaClient.create_collection(name="autogpt")
# key off cfg.openai_embeddings_model to choose between sentence transformers and openai embeddings
self.useOpenAIEmbeddings = bool(cfg.openai_embeddings_model)

def add(self, text: str):
"""
@@ -61,32 +52,34 @@ def add(self, text: str):
"""
if "Command Error:" in text:
return ""
self.data.texts.append(text)

embedding = get_ada_embedding(text)

vector = np.array(embedding).astype(np.float32)
vector = vector[np.newaxis, :]
self.data.embeddings = np.concatenate(
[
self.data.embeddings,
vector,
],
axis=0,
)

with open(self.filename, "wb") as f:
out = orjson.dumps(self.data, option=SAVE_OPTIONS)
f.write(out)

current_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
metadata = {"time_added": current_time}
if self.useOpenAIEmbeddings:
embeddings = create_embedding(text)
self.chromaCollection.add(
embeddings=[embeddings],
ids=[str(uuid.uuid4())],
metadatas=[metadata]
)
else:
self.chromaCollection.add(
documents=[text],
ids=[str(uuid.uuid4())],
metadatas=[metadata]
)
return text

def clear(self) -> str:
"""
Clears the data in memory.

Returns: A message indicating that the memory has been cleared.
Returns: A message indicating that the db has been cleared.
"""
self.data = CacheContent()

chroma_client = self.chromaClient
chroma_client.reset()
self.chromaCollection = chroma_client.create_collection(name="autogpt")
return "Obliviated"

def get(self, data: str) -> list[Any] | None:
@@ -98,29 +91,41 @@ def get(self, data: str) -> list[Any] | None:

Returns: The most relevant data.
"""
return self.get_relevant(data, 1)

def get_relevant(self, text: str, k: int) -> list[Any]:
""" "
matrix-vector mult to find score-for-each-row-of-matrix
get indices for top-k winning scores
return texts for those indices
Args:
text: str
k: int

Returns: List[str]
"""
embedding = get_ada_embedding(text)

scores = np.dot(self.data.embeddings, embedding)

top_k_indices = np.argsort(scores)[-k:][::-1]

return [self.data.texts[i] for i in top_k_indices]
results = None
if self.useOpenAIEmbeddings:
embeddings = create_embedding(data)
results = self.chromaCollection.query(
query_embeddings=[embeddings],
n_results=1
)
else:
results = self.chromaCollection.query(
query_texts=[data],
n_results=1
)
return results['documents']

def get_relevant(self, text: str, k: int) -> List[Any]:
results = None
try:
if self.useOpenAIEmbeddings:
embeddings = create_embedding(text)
results = self.chromaCollection.query(
query_embeddings=[embeddings],
n_results=min(k, self.chromaCollection.count())
)
else:
results = self.chromaCollection.query(
query_texts=[text],
n_results=min(k, self.chromaCollection.count())
)
except NoIndexException:
# no index yet - a common, harmless condition for first-run users
pass
return results['documents'] if results else []

def get_stats(self) -> int:
"""
Returns: The stats of the local cache.
Returns: The number of items that have been added to the Chroma collection
"""
return len(self.data.texts), self.data.embeddings.shape
return self.chromaCollection.count()
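For context on what `chromaCollection.query()` now does for `LocalCache`, here is a dependency-free sketch of the nearest-neighbour lookup it performs — the same score-and-take-top-k that the removed numpy code implemented with a matrix-vector product (names and toy vectors are illustrative only):

```python
from typing import List, Tuple

def top_k_by_dot_product(
    store: List[Tuple[str, List[float]]],  # (text, embedding) pairs
    query: List[float],
    k: int,
) -> List[str]:
    # Score every stored embedding against the query by dot product,
    # then return the texts with the k highest scores.
    scored = [
        (sum(q * e for q, e in zip(query, emb)), text)
        for text, emb in store
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [text for _, text in scored[:k]]

store = [
    ("apples", [1.0, 0.0]),
    ("oranges", [0.0, 1.0]),
    ("fruit", [0.7, 0.7]),
]
print(top_k_by_dot_product(store, [1.0, 0.1], 2))  # ['apples', 'fruit']
```

Chroma does the same ranking against its persisted index, which is why the hand-rolled embedding matrix and JSON file could be deleted.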
6 changes: 3 additions & 3 deletions autogpt/memory/milvus.py
@@ -4,7 +4,7 @@
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections

from autogpt.config import Config
from autogpt.llm_utils import get_ada_embedding
from autogpt.llm_utils import get_embedding
from autogpt.memory.base import MemoryProviderSingleton


@@ -99,7 +99,7 @@ def add(self, data) -> str:
Returns:
str: log.
"""
embedding = get_ada_embedding(data)
embedding = get_embedding(data)
result = self.collection.insert([[embedding], [data]])
_text = (
"Inserting data into memory at primary key: "
@@ -141,7 +141,7 @@ def get_relevant(self, data: str, num_relevant: int = 5):
list: The top-k relevant data.
"""
# search the embedding and return the most relevant text.
embedding = get_ada_embedding(data)
embedding = get_embedding(data)
search_params = {
"metrics_type": "IP",
"params": {"nprobe": 8},
7 changes: 3 additions & 4 deletions autogpt/memory/pinecone.py
@@ -1,9 +1,8 @@
import pinecone
from colorama import Fore, Style

from autogpt.llm_utils import get_ada_embedding
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton
from autogpt.llm_utils import create_embedding


class PineconeMemory(MemoryProviderSingleton):
@@ -44,7 +43,7 @@ def __init__(self, cfg):
self.index = pinecone.Index(table_name)

def add(self, data):
vector = get_ada_embedding(data)
vector = create_embedding(data)
# no metadata here. We may wish to change that long term.
self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
@@ -64,7 +63,7 @@ def get_relevant(self, data, num_relevant=5):
:param data: The data to compare to.
:param num_relevant: The number of relevant data to return. Defaults to 5
"""
query_embedding = get_ada_embedding(data)
query_embedding = create_embedding(data)
results = self.index.query(
query_embedding, top_k=num_relevant, include_metadata=True
)
6 changes: 3 additions & 3 deletions autogpt/memory/redismem.py
@@ -10,9 +10,9 @@
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query

from autogpt.llm_utils import get_ada_embedding
from autogpt.logs import logger
from autogpt.memory.base import MemoryProviderSingleton
from autogpt.llm_utils import create_embedding

SCHEMA = [
TextField("data"),
@@ -88,7 +88,7 @@ def add(self, data: str) -> str:
"""
if "Command Error:" in data:
return ""
vector = get_ada_embedding(data)
vector = create_embedding(data)
vector = np.array(vector).astype(np.float32).tobytes()
data_dict = {b"data": data, "embedding": vector}
pipe = self.redis.pipeline()
@@ -130,7 +130,7 @@ def get_relevant(self, data: str, num_relevant: int = 5) -> list[Any] | None:

Returns: A list of the most relevant data.
"""
query_embedding = get_ada_embedding(data)
query_embedding = create_embedding(data)
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
query = (
Query(base_query)
6 changes: 3 additions & 3 deletions autogpt/memory/weaviate.py
@@ -3,7 +3,7 @@
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

from autogpt.llm_utils import get_ada_embedding
from autogpt.llm_utils import get_embedding
from autogpt.memory.base import MemoryProviderSingleton


@@ -70,7 +70,7 @@ def _build_auth_credentials(self, cfg):
return None

def add(self, data):
vector = get_ada_embedding(data)
vector = get_embedding(data)

doc_uuid = generate_uuid5(data, self.index)
data_object = {"raw_text": data}
@@ -99,7 +99,7 @@ def clear(self):
return "Obliterated"

def get_relevant(self, data, num_relevant=5):
query_embedding = get_ada_embedding(data)
query_embedding = get_embedding(data)
try:
results = (
self.client.query.get(self.index, ["raw_text"])
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,6 +14,7 @@ duckduckgo-search
google-api-python-client #(https://developers.google.com/custom-search/v1/overview)
pinecone-client==2.2.1
redis
chromadb
orjson
Pillow
selenium==4.1.4
4 changes: 2 additions & 2 deletions tests/integration/weaviate_memory_tests.py
@@ -5,7 +5,7 @@
from weaviate.util import get_valid_uuid

from autogpt.config import Config
from autogpt.llm_utils import get_ada_embedding
from autogpt.llm_utils import get_embedding
from autogpt.memory.weaviate import WeaviateMemory


@@ -75,7 +75,7 @@ def test_get(self):
uuid=get_valid_uuid(uuid4()),
data_object={"raw_text": doc},
class_name=self.index,
vector=get_ada_embedding(doc),
vector=get_embedding(doc),
)

batch.flush()