
Commit df7ddd4

Merge pull request #372 from BillSchumacher/redis-backend
Implement Local Cache and Redis Memory backend. Removes dependence on Pinecone
2 parents 1ea9c36 + a861dec commit df7ddd4

10 files changed: +391 -24 lines changed


README.md

+33
@@ -141,6 +141,39 @@ export CUSTOM_SEARCH_ENGINE_ID="YOUR_CUSTOM_SEARCH_ENGINE_ID"
 
 ```
 
+## Redis Setup
+
+Install Docker Desktop.
+
+Run:
+```
+docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest
+```
+
+Set the following environment variables:
+```
+MEMORY_BACKEND=redis
+REDIS_HOST=localhost
+REDIS_PORT=6379
+REDIS_PASSWORD=
+```
+
+Note that this setup is not intended to face the internet and is not secure; do not expose Redis to the internet without a password, and preferably not at all.
+
+You can optionally set
+
+```
+WIPE_REDIS_ON_START=False
+```
+
+to persist memory stored in Redis.
+
+You can specify the memory index for Redis using the following:
+
+```
+MEMORY_INDEX=whatever
+```
+
 ## 🌲 Pinecone API Key Setup
 
 Pinecone enable a vector based memory so a vast memory can be stored and only relevant memories
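The variables documented above map directly onto a standard redis-py connection. A quick sanity check, assuming redis-py is installed (it is added to requirements.txt below) and the redis-stack-server container is running; the defaults mirror the values shown in the README:

```python
import os

import redis  # redis-py, added to requirements.txt in this commit

# Defaults mirror the README; override via the same environment variables.
client = redis.Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    password=os.getenv("REDIS_PASSWORD", "") or None,
    db=0,  # config.py notes that indexes must be created on db 0
)

print(client.ping())  # True if the container started above is reachable
```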

requirements.txt

+2
@@ -12,4 +12,6 @@ docker
 duckduckgo-search
 google-api-python-client #(https://developers.google.com/custom-search/v1/overview)
 pinecone-client==2.2.1
+redis
+orjson
 Pillow

scripts/commands.py

+4-3
@@ -1,6 +1,6 @@
 import browse
 import json
-from memory import PineconeMemory
+from memory import get_memory
 import datetime
 import agent_manager as agents
 import speak
@@ -53,10 +53,11 @@ def get_command(response):
 
 
 def execute_command(command_name, arguments):
-    memory = PineconeMemory()
+    memory = get_memory(cfg)
+
     try:
         if command_name == "google":
-
+
             # Check if the Google API key is set and use the official search method
             # If the API key is not set or has only whitespaces, use the unofficial search method
             if cfg.google_api_key and (cfg.google_api_key.strip() if cfg.google_api_key else None):
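For orientation, the memory handle obtained at the top of execute_command is the same backend-specific singleton that main.py initialises. A minimal sketch of the pattern this enables inside a command branch; the "memory_add" branch and the module-level cfg below are illustrative assumptions, not taken from this diff:

```python
# Illustrative sketch only; command names and control flow are assumptions,
# not this commit's version of commands.py.
from config import Config
from memory import get_memory

cfg = Config()


def execute_command(command_name, arguments):
    memory = get_memory(cfg)  # LocalCache, RedisMemory, or PineconeMemory
    try:
        if command_name == "memory_add":
            # Every backend exposes the same add()/get_relevant() interface,
            # so command code does not care which one is configured.
            return memory.add(arguments["string"])
        return f"Unknown command: {command_name}"
    except Exception as e:
        return f"Error: {str(e)}"
```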

scripts/config.py

+14-2
@@ -1,11 +1,12 @@
+import abc
 import os
 import openai
 from dotenv import load_dotenv
 # Load environment variables from .env file
 load_dotenv()
 
 
-class Singleton(type):
+class Singleton(abc.ABCMeta, type):
     """
     Singleton metaclass for ensuring only one instance of a class.
     """
@@ -20,6 +21,10 @@ def __call__(cls, *args, **kwargs):
         return cls._instances[cls]
 
 
+class AbstractSingleton(abc.ABC, metaclass=Singleton):
+    pass
+
+
 class Config(metaclass=Singleton):
     """
     Configuration class to store the state of bools for different scripts access.
@@ -59,7 +64,14 @@ def __init__(self):
         # User agent headers to use when browsing web
         # Some websites might just completely deny request with an error code if no user agent was found.
         self.user_agent_header = {"User-Agent":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36"}
-
+        self.redis_host = os.getenv("REDIS_HOST", "localhost")
+        self.redis_port = os.getenv("REDIS_PORT", "6379")
+        self.redis_password = os.getenv("REDIS_PASSWORD", "")
+        self.wipe_redis_on_start = os.getenv("WIPE_REDIS_ON_START", "True") == 'True'
+        self.memory_index = os.getenv("MEMORY_INDEX", 'auto-gpt')
+        # Note that indexes must be created on db 0 in redis, this is not configurable.
+
+        self.memory_backend = os.getenv("MEMORY_BACKEND", 'local')
         # Initialize the OpenAI API client
         openai.api_key = self.openai_api_key
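The reason Singleton now derives from abc.ABCMeta as well as type is that AbstractSingleton must combine abc.ABC (whose metaclass is ABCMeta) with the singleton metaclass; without a common metaclass Python would raise a metaclass conflict. A minimal standalone sketch of the pattern, with the __call__ body filled in the conventional way and an illustrative subclass that is not part of the commit:

```python
import abc


class Singleton(abc.ABCMeta, type):
    """Metaclass that caches a single instance per class."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct the instance only the first time the class is called.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class AbstractSingleton(abc.ABC, metaclass=Singleton):
    pass


# Illustrative subclass: abstract methods are still enforced, and repeated
# construction returns the same cached object.
class Provider(AbstractSingleton):
    @abc.abstractmethod
    def add(self, data): ...


class ConcreteProvider(Provider):
    def add(self, data):
        return data


assert ConcreteProvider() is ConcreteProvider()
```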

scripts/main.py

+2-5
@@ -1,7 +1,7 @@
 import json
 import random
 import commands as cmd
-from memory import PineconeMemory
+from memory import get_memory
 import data
 import chat
 from colorama import Fore, Style
@@ -281,12 +281,9 @@ def parse_arguments():
 # Make a constant:
 user_input = "Determine which next command to use, and respond using the format specified above:"
 
-# raise an exception if pinecone_api_key or region is not provided
-if not cfg.pinecone_api_key or not cfg.pinecone_region: raise Exception("Please provide pinecone_api_key and pinecone_region")
 # Initialize memory and make sure it is empty.
 # this is particularly important for indexing and referencing pinecone memory
-memory = PineconeMemory()
-memory.clear()
+memory = get_memory(cfg, init=True)
 print('Using memory of type: ' + memory.__class__.__name__)
 
 # Interaction Loop

scripts/memory/__init__.py

+44
@@ -0,0 +1,44 @@
+from memory.local import LocalCache
+try:
+    from memory.redismem import RedisMemory
+except ImportError:
+    print("Redis not installed. Skipping import.")
+    RedisMemory = None
+
+try:
+    from memory.pinecone import PineconeMemory
+except ImportError:
+    print("Pinecone not installed. Skipping import.")
+    PineconeMemory = None
+
+
+def get_memory(cfg, init=False):
+    memory = None
+    if cfg.memory_backend == "pinecone":
+        if not PineconeMemory:
+            print("Error: Pinecone is not installed. Please install pinecone"
+                  " to use Pinecone as a memory backend.")
+        else:
+            memory = PineconeMemory(cfg)
+            if init:
+                memory.clear()
+    elif cfg.memory_backend == "redis":
+        if not RedisMemory:
+            print("Error: Redis is not installed. Please install redis-py to"
+                  " use Redis as a memory backend.")
+        else:
+            memory = RedisMemory(cfg)
+
+    if memory is None:
+        memory = LocalCache(cfg)
+        if init:
+            memory.clear()
+    return memory
+
+
+__all__ = [
+    "get_memory",
+    "LocalCache",
+    "RedisMemory",
+    "PineconeMemory",
+]
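Callers only interact with this factory; main.py passes init=True to clear whichever backend was selected, and commands.py calls it without init. A small usage sketch, with an illustrative stand-in config object in place of the real Config() singleton:

```python
# MinimalConfig is a stand-in with just the fields the factory and LocalCache
# read; real callers pass the Config() singleton from config.py.
from memory import get_memory


class MinimalConfig:
    memory_backend = "redis"   # "pinecone", "redis", or anything else for LocalCache
    memory_index = "auto-gpt"  # LocalCache persists to "<memory_index>.json"
    redis_host = "localhost"
    redis_port = "6379"
    redis_password = ""
    wipe_redis_on_start = True


# If redis-py (or pinecone-client) is not installed, get_memory prints a
# warning and falls back to LocalCache; init=True clears the chosen backend.
memory = get_memory(MinimalConfig(), init=True)
print('Using memory of type: ' + memory.__class__.__name__)
```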

scripts/memory/base.py

+31
@@ -0,0 +1,31 @@
+"""Base class for memory providers."""
+import abc
+from config import AbstractSingleton
+import openai
+
+
+def get_ada_embedding(text):
+    text = text.replace("\n", " ")
+    return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
+
+
+class MemoryProviderSingleton(AbstractSingleton):
+    @abc.abstractmethod
+    def add(self, data):
+        pass
+
+    @abc.abstractmethod
+    def get(self, data):
+        pass
+
+    @abc.abstractmethod
+    def clear(self):
+        pass
+
+    @abc.abstractmethod
+    def get_relevant(self, data, num_relevant=5):
+        pass
+
+    @abc.abstractmethod
+    def get_stats(self):
+        pass
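A new backend only has to subclass MemoryProviderSingleton and fill in these five methods; because the base class is an AbstractSingleton, each concrete provider is also instantiated at most once. A toy provider, as an illustration; the class and its substring matching are not part of the commit:

```python
from memory.base import MemoryProviderSingleton


class ToyMemory(MemoryProviderSingleton):
    """Illustrative provider: keeps raw strings in a list, no embeddings."""

    def __init__(self):
        self.items = []

    def add(self, data):
        self.items.append(data)
        return data

    def get(self, data):
        return self.get_relevant(data, 1)

    def clear(self):
        self.items = []
        return "Memory cleared"

    def get_relevant(self, data, num_relevant=5):
        # Naive relevance: substring match instead of embedding similarity.
        return [item for item in self.items if data in item][:num_relevant]

    def get_stats(self):
        return len(self.items)
```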

scripts/memory/local.py

+114
@@ -0,0 +1,114 @@
+import dataclasses
+import orjson
+from typing import Any, List, Optional
+import numpy as np
+import os
+from memory.base import MemoryProviderSingleton, get_ada_embedding
+
+
+EMBED_DIM = 1536
+SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS
+
+
+def create_default_embeddings():
+    return np.zeros((0, EMBED_DIM)).astype(np.float32)
+
+
+@dataclasses.dataclass
+class CacheContent:
+    texts: List[str] = dataclasses.field(default_factory=list)
+    embeddings: np.ndarray = dataclasses.field(
+        default_factory=create_default_embeddings
+    )
+
+
+class LocalCache(MemoryProviderSingleton):
+
+    # on load, load our database
+    def __init__(self, cfg) -> None:
+        self.filename = f"{cfg.memory_index}.json"
+        if os.path.exists(self.filename):
+            with open(self.filename, 'rb') as f:
+                loaded = orjson.loads(f.read())
+                self.data = CacheContent(**loaded)
+        else:
+            self.data = CacheContent()
+
+    def add(self, text: str):
+        """
+        Add text to our list of texts, add embedding as row to our
+        embeddings-matrix
+
+        Args:
+            text: str
+
+        Returns: The text that was added.
+        """
+        if 'Command Error:' in text:
+            return ""
+        self.data.texts.append(text)
+
+        embedding = get_ada_embedding(text)
+
+        vector = np.array(embedding).astype(np.float32)
+        vector = vector[np.newaxis, :]
+        self.data.embeddings = np.concatenate(
+            [
+                self.data.embeddings,
+                vector,
+            ],
+            axis=0,
+        )
+
+        with open(self.filename, 'wb') as f:
+            out = orjson.dumps(
+                self.data,
+                option=SAVE_OPTIONS
+            )
+            f.write(out)
+        return text
+
+    def clear(self) -> str:
+        """
+        Clears the local cache.
+
+        Returns: A message indicating that the memory has been cleared.
+        """
+        self.data = CacheContent()
+        return "Obliviated"
+
+    def get(self, data: str) -> Optional[List[Any]]:
+        """
+        Gets the data from the memory that is most relevant to the given data.
+
+        Args:
+            data: The data to compare to.
+
+        Returns: The most relevant data.
+        """
+        return self.get_relevant(data, 1)
+
+    def get_relevant(self, text: str, k: int) -> List[Any]:
+        """
+        matrix-vector mult to find score-for-each-row-of-matrix
+        get indices for top-k winning scores
+        return texts for those indices
+        Args:
+            text: str
+            k: int
+
+        Returns: List[str]
+        """
+        embedding = get_ada_embedding(text)
+
+        scores = np.dot(self.data.embeddings, embedding)
+
+        top_k_indices = np.argsort(scores)[-k:][::-1]
+
+        return [self.data.texts[i] for i in top_k_indices]
+
+    def get_stats(self):
+        """
+        Returns: The stats of the local cache.
+        """
+        return len(self.data.texts), self.data.embeddings.shape
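The retrieval step in get_relevant is just a matrix–vector product followed by an argsort over the scores. A standalone sketch of that scoring logic, using tiny made-up 3-dimensional vectors in place of the 1536-dimensional ada embeddings:

```python
import numpy as np

# Toy stand-ins for self.data.texts / self.data.embeddings; real rows come
# from get_ada_embedding and have EMBED_DIM = 1536 columns.
texts = ["buy milk", "write report", "walk the dog"]
embeddings = np.array([
    [0.9, 0.1, 0.0],
    [0.1, 0.9, 0.1],
    [0.0, 0.2, 0.9],
], dtype=np.float32)
query = np.array([0.1, 0.8, 0.2], dtype=np.float32)  # embedding of the query text

k = 2
scores = np.dot(embeddings, query)             # one dot-product score per stored text
top_k_indices = np.argsort(scores)[-k:][::-1]  # indices of the k largest scores, best first
print([texts[i] for i in top_k_indices])       # ['write report', 'walk the dog']
```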

scripts/memory.py renamed to scripts/memory/pinecone.py

+4-14
@@ -1,21 +1,11 @@
-from config import Config, Singleton
-import pinecone
-import openai
-
-cfg = Config()
-
-
-def get_ada_embedding(text):
-    text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]
 
+import pinecone
 
-def get_text_from_embedding(embedding):
-    return openai.Embedding.retrieve(embedding, model="text-embedding-ada-002")["data"][0]["text"]
+from memory.base import MemoryProviderSingleton, get_ada_embedding
 
 
-class PineconeMemory(metaclass=Singleton):
-    def __init__(self):
+class PineconeMemory(MemoryProviderSingleton):
+    def __init__(self, cfg):
         pinecone_api_key = cfg.pinecone_api_key
         pinecone_region = cfg.pinecone_region
         pinecone.init(api_key=pinecone_api_key, environment=pinecone_region)

0 commit comments
