Skip to content

Commit

Permalink
Add support for Pinecone
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp committed Jan 16, 2024
1 parent 4d7e45a commit a93ec58
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
2 changes: 1 addition & 1 deletion api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ async def query(payload: RequestPayload):
index_name=payload.index_name, credentials=payload.vector_database
)
chunks = await vector_service.query(input=payload.input, top_k=4)
documents = await vector_service.convert_to_dict(points=chunks)
documents = await vector_service.convert_to_dict(chunks=chunks)
results = await vector_service.rerank(query=payload.input, documents=documents)
return {"success": True, "data": results}
8 changes: 7 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ distro==1.9.0
dnspython==2.4.2
fastapi==0.109.0
fastavro==1.9.3
filelock==3.13.1
frozenlist==1.4.1
fsspec==2023.12.2
greenlet==3.0.3
Expand All @@ -30,12 +31,16 @@ hpack==4.0.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.26.0
huggingface-hub==0.20.2
hyperframe==6.0.1
idna==3.6
importlib-metadata==6.11.0
Jinja2==3.1.3
joblib==1.3.2
litellm==1.17.5
llama-index==0.9.30
loguru==0.7.2
MarkupSafe==2.1.3
marshmallow==3.20.2
multidict==6.0.4
mypy-extensions==1.0.0
Expand All @@ -47,7 +52,7 @@ openai==1.7.2
packaging==23.2
pandas==2.1.4
pathspec==0.12.1
pinecone-client==2.2.4
pinecone-client==3.0.0
platformdirs==4.1.0
portalocker==2.8.2
protobuf==4.25.2
Expand All @@ -72,6 +77,7 @@ SQLAlchemy==2.0.25
starlette==0.35.1
tenacity==8.2.3
tiktoken==0.5.2
tokenizers==0.15.0
toml==0.10.2
tqdm==4.66.1
typing-inspect==0.9.0
Expand Down
57 changes: 36 additions & 21 deletions service/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from litellm import embedding
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from pinecone import Pinecone, ServerlessSpec

from models.vector_database import VectorDatabase


Expand Down Expand Up @@ -54,33 +56,46 @@ def __init__(self, index_name: str, dimension: int, credentials: dict):
super().__init__(
index_name=index_name, dimension=dimension, credentials=credentials
)
pinecone.init(
api_key=credentials["PINECONE_API_KEY"],
environment=credentials["PINECONE_ENVIRONMENT"],
)
# Create a new vector index if it doesn't
# exist dimensions should be passed in the arguments
if index_name not in pinecone.list_indexes():
pinecone = Pinecone(api_key=credentials["api_key"])
if index_name not in [index.name for index in pinecone.list_indexes()]:
pinecone.create_index(
name=index_name, metric="cosine", shards=1, dimension=dimension
name=self.index_name,
dimension=1024,
metric="cosine",
spec=ServerlessSpec(cloud="aws", region="us-west-2"),
)
self.index = pinecone.Index(index_name=self.index_name)
self.index = pinecone.Index(name=self.index_name)

async def convert_to_dict(self, documents: list):
pass
async def convert_to_dict(self, chunks: List):
docs = [
{
"content": chunk.get("metadata")["content"],
"page_label": chunk.get("metadata")["page_label"],
"file_url": chunk.get("metadata")["file_url"],
}
for chunk in chunks
]
return docs

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]):
self.index.upsert(vectors=embeddings)

async def query(
self, queries: List[ndarray], top_k: int, include_metadata: bool = True
):
async def query(self, input: str, top_k: 4, include_metadata: bool = True):
vectors = []
embedding_object = embedding(
model="huggingface/intfloat/multilingual-e5-large",
input=input,
api_key=config("HUGGINGFACE_API_KEY"),
)
for vector in embedding_object.data:
if vector["object"] == "embedding":
vectors.append(vector["embedding"])
results = self.index.query(
queries=queries,
vector=vectors,
top_k=top_k,
include_metadata=include_metadata,
)
return results["results"][0]["matches"]
return results["matches"]


class QdrantService(VectorService):
Expand All @@ -105,14 +120,14 @@ def __init__(self, index_name: str, dimension: int, credentials: dict):
),
)

async def convert_to_dict(self, points: List[rest.PointStruct]):
async def convert_to_dict(self, chunks: List[rest.PointStruct]):
docs = [
{
"content": point.payload.get("content"),
"page_label": point.payload.get("page_label"),
"file_url": point.payload.get("file_url"),
"content": chunk.payload.get("content"),
"page_label": chunk.payload.get("page_label"),
"file_url": chunk.payload.get("file_url"),
}
for point in points
for chunk in chunks
]
return docs

Expand Down

0 comments on commit a93ec58

Please sign in to comment.