Fix doc lookup (#917)
ahuang11 authored Jan 6, 2025
1 parent 7f6ed79 commit 9efa066
Showing 4 changed files with 305 additions and 23 deletions.
5 changes: 4 additions & 1 deletion lumen/ai/controls.py
@@ -283,7 +283,10 @@ def _add_document(
file, file_extension=document_controls.extension
).text_content

metadata = f"Filename: {document_controls.filename} " + document_controls._metadata_input.value
metadata = {
"filename": document_controls.filename,
"comments": document_controls._metadata_input.value,
}
document = {"text": text, "metadata": metadata}
if "document_sources" in self._memory:
self._memory["document_sources"].append(document)
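With this change, each uploaded document carries a structured metadata dict (filename plus free-form comments) instead of a single concatenated string, so downstream lookups can inspect the fields individually. A minimal sketch of the payload shape this produces, with illustrative values:

```python
# Illustrative only: the filename and comment text are made-up values.
document = {
    "text": "...extracted document text...",
    "metadata": {
        "filename": "quarterly_report.pdf",        # document_controls.filename
        "comments": "Uploaded for the Q4 review",  # document_controls._metadata_input.value
    },
}
# This is what gets appended to self._memory["document_sources"].
```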
6 changes: 4 additions & 2 deletions lumen/ai/tools.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import re

from typing import Any

import param
@@ -68,10 +70,10 @@ def __init__(self, **params):
def _update_vector_store(self, _, __, sources):
for source in sources:
if not self.vector_store.query(source["text"], threshold=1):
self.vector_store.add([{"text": source["text"], "metadata": source.get("metadata", "")}])
self.vector_store.add([{"text": source["text"], "metadata": source.get("metadata", {})}])

async def respond(self, messages: list[Message], **kwargs: Any) -> str:
query = messages[-1]["content"]
query = re.findall(r"'(.*?)'", messages[-1]["content"])[0]
results = self.vector_store.query(query, top_k=self.n, threshold=self.min_similarity)
closest_doc_chunks = [
f"{result['text']} (Relevance: {result['similarity']:.1f} - Metadata: {result['metadata']}"
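The lookup tool also changes how it derives its vector-store query: instead of embedding the entire user message, it now extracts the first single-quoted span and searches for that. A small sketch of the extraction with a made-up message (it assumes the message always contains at least one quoted span; otherwise the `[0]` index would raise `IndexError`):

```python
import re

# Hypothetical incoming message; the quoted span is the term to look up.
message = "Please check the docs for 'vector store filters' and summarize them."
query = re.findall(r"'(.*?)'", message)[0]
print(query)  # -> vector store filters
```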
66 changes: 46 additions & 20 deletions lumen/ai/vector_store.py
@@ -243,7 +243,7 @@ def query(
sorted_indices = np.argsort(similarities)[::-1]
for idx in sorted_indices:
similarity = similarities[idx]
if similarity < threshold:
if similarity <= threshold:
continue
results.append(
{
@@ -346,7 +346,7 @@ def _setup_database(self) -> None:
"""Set up the DuckDB database with necessary tables and indexes."""
self.connection.execute("INSTALL 'vss';")
self.connection.execute("LOAD 'vss';")

self.connection.execute("SET hnsw_enable_experimental_persistence = true;")
self.connection.execute("CREATE SEQUENCE IF NOT EXISTS documents_id_seq;")

self.connection.execute(
@@ -433,19 +433,37 @@ def query(

base_query = f"""
SELECT id, text, metadata,
1 - array_distance(embedding, ?::FLOAT[{self.vocab_size}], 'cosine') AS similarity
array_cosine_similarity(embedding, ?::REAL[{self.vocab_size}]) AS similarity
FROM documents
WHERE 1=1
"""
params = [query_embedding]

if filters:
for key, value in filters.items():
# Use json_extract_string for string comparison
base_query += f" AND json_extract_string(metadata, '$.{key}') = ?"
params.append(str(value))
if isinstance(value, list):
if not all(isinstance(v, (str,int,float)) for v in value):
print(f"Invalid value in filter {key}. Can only filter by string, integer or float.")
return []
placeholders = ", ".join("?" for _ in value)
base_query += (
f" AND json_extract_string(metadata, '$.{key}') IS NOT NULL "
f"AND json_extract_string(metadata, '$.{key}') IN ({placeholders})"
)
params.extend([str(v) for v in value])
elif isinstance(value, (str,int,float)):
# Single value equality check
base_query += (
f" AND json_extract_string(metadata, '$.{key}') IS NOT NULL "
f"AND json_extract_string(metadata, '$.{key}') = ?"
)
params.append(str(value))
else:
print(f"Invalid value in filter {key}. Can only filter by string, integer or float.")
return []

base_query += f" AND 1 - array_distance(embedding, ?::FLOAT[{self.vocab_size}], 'cosine') >= ?"

base_query += f" AND array_cosine_similarity(embedding, ?::REAL[{self.vocab_size}]) >= ?"
params.extend([query_embedding, threshold])

base_query += """
@@ -454,17 +472,21 @@ def query(
"""
params.append(top_k)

result = self.connection.execute(base_query, params).fetchall()
try:
result = self.connection.execute(base_query, params).fetchall()

return [
{
"id": row[0],
"text": row[1],
"metadata": json.loads(row[2]),
"similarity": row[3],
}
for row in result
]
return [
{
"id": row[0],
"text": row[1],
"metadata": json.loads(row[2]),
"similarity": row[3],
}
for row in result
]
except duckdb.Error as e:
print(f"Error during query: {e}")
return []

def filter_by(
self, filters: dict, limit: int | None = None, offset: int = 0
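Because metadata is stored as JSON, the DuckDB `query` method now builds its `WHERE` clause per filter key: a single string/int/float value becomes a `json_extract_string(...) = ?` comparison, a list of such values becomes an `IN (...)` clause, and any other value type is rejected with an empty result. A hedged usage sketch (the store, its contents, and the query strings are illustrative; see the tests below for concrete setup):

```python
# Assumes `store` is a populated DuckDBVectorStore (see the test fixtures below).

# Single-value filter: metadata.title must equal "receipt".
results = store.query("monthly expenses", filters={"title": "receipt"})

# List filter: metadata.department may match any listed value (rendered as IN (...)).
results = store.query(
    "monthly expenses",
    filters={"department": ["accounting", "management"]},
    top_k=5,
    threshold=0.1,
)

# Unsupported filter value types (e.g. a dict) short-circuit to an empty list.
assert store.query("anything", filters={"title": {"nested": "dict"}}) == []
```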
@@ -525,6 +547,10 @@ def delete(self, ids: list[int]) -> None:
self.connection.execute(query, ids)

def clear(self) -> None:
"""Clear all items from the DuckDB vector store."""
self.connection.execute("DELETE FROM documents;")
self.connection.execute("ALTER SEQUENCE documents_id_seq RESTART WITH 1;")
"""
Clear all entries from both tables and reset sequence by dropping/recreating everything.
"""
self.connection.execute("DROP TABLE IF EXISTS metadata_embeddings;")
self.connection.execute("DROP TABLE IF EXISTS documents;")
self.connection.execute("DROP SEQUENCE IF EXISTS documents_id_seq;")
self._setup_database()
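`clear()` on the DuckDB store now rebuilds the schema from scratch: it drops the `metadata_embeddings` and `documents` tables along with the id sequence and re-runs `_setup_database`, rather than deleting rows and resetting the sequence in place. A small sketch of the observable effect, assuming the recreated sequence hands out ids starting at 1 (the ids shown are illustrative, not guaranteed):

```python
# Assumes `store` is a DuckDBVectorStore instance.
store.add([{"text": "first batch", "metadata": {"title": "old"}}])
store.clear()                               # drop tables + sequence, recreate schema
ids = store.add([{"text": "fresh start"}])
print(ids)                                  # e.g. [1] -- numbering restarts after the rebuild
```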
251 changes: 251 additions & 0 deletions lumen/tests/ai/test_vector_store.py
@@ -0,0 +1,251 @@
import pytest

try:
import lumen.ai # noqa
except ModuleNotFoundError:
pytest.skip("lumen.ai could not be imported, skipping tests.", allow_module_level=True)

from lumen.ai.embeddings import NumpyEmbeddings
from lumen.ai.vector_store import DuckDBVectorStore, NumpyVectorStore


@pytest.mark.xdist_group("vss")
class VectorStoreTestKit:
"""
A base class (test kit) that provides the *common* tests and fixture definitions.
"""

@pytest.fixture
def store(self):
"""
Must be overridden in the subclass to return a fresh store instance.
"""
raise NotImplementedError("Subclasses must override `store` fixture")

@pytest.fixture
def empty_store(self, store):
"""
Returns an empty store (clears whatever the `store` fixture gave us).
"""
store.clear()
return store

@pytest.fixture
def store_with_three_docs(self, store):
"""
Returns a store preloaded with three documents for convenience.
"""
items = [
{
"text": "Food: $10, Drinks: $5, Total: $15",
"metadata": {"title": "receipt", "department": "accounting"},
},
{
"text": "In the org chart, the CEO reports to the board.",
"metadata": {"title": "org_chart", "department": "management"},
},
{
"text": "A second receipt with different details",
"metadata": {"title": "receipt", "department": "accounting"},
},
]
store.add(items)
return store

def test_filter_by_title_receipt(self, store_with_three_docs):
filtered = store_with_three_docs.filter_by({"title": "receipt"})
assert len(filtered) == 2

def test_filter_by_department_management(self, store_with_three_docs):
filtered_mgmt = store_with_three_docs.filter_by({"department": "management"})
assert len(filtered_mgmt) == 1
assert filtered_mgmt[0]["metadata"]["title"] == "org_chart"

def test_filter_by_nonexistent_filter(self, store_with_three_docs):
filtered_none = store_with_three_docs.filter_by({"title": "does_not_exist"})
assert len(filtered_none) == 0

def test_filter_by_limit_and_offset(self, store_with_three_docs):
filtered_limited = store_with_three_docs.filter_by({"title": "receipt"}, limit=1)
assert len(filtered_limited) == 1

filtered_offset = store_with_three_docs.filter_by({"title": "receipt"}, limit=1, offset=1)
assert len(filtered_offset) == 1
assert filtered_offset[0]["text"] == "A second receipt with different details"

def test_query_with_filter_title_org_chart(self, store_with_three_docs):
query_text = "CEO reports to?"
results = store_with_three_docs.query(query_text)
assert len(results) >= 1
assert results[0]["metadata"]["title"] == "org_chart"
assert "CEO" in results[0]["text"]

results = store_with_three_docs.query(query_text, filters={"title": "org_chart"})
assert len(results) == 1
assert results[0]["metadata"]["title"] == "org_chart"
assert "CEO" in results[0]["text"]

def test_query_empty_store(self, empty_store):
results = empty_store.query("some query")
assert results == []

def test_filter_empty_store(self, empty_store):
filtered = empty_store.filter_by({"key": "value"})
assert filtered == []

def test_delete_empty_store(self, empty_store):
empty_store.delete([1, 2, 3])
results = empty_store.query("some query")
assert results == []

def test_delete_empty_list(self, store_with_three_docs):
original_count = len(store_with_three_docs.filter_by({}))
store_with_three_docs.delete([])
after_count = len(store_with_three_docs.filter_by({}))
assert original_count == after_count

def test_delete_specific_ids(self, store_with_three_docs):
all_docs = store_with_three_docs.filter_by({})
target_id = all_docs[0]["id"]
store_with_three_docs.delete([target_id])

remaining_docs = store_with_three_docs.filter_by({})
remaining_ids = [doc["id"] for doc in remaining_docs]
assert target_id not in remaining_ids

def test_clear_store(self, store_with_three_docs):
store_with_three_docs.clear()
results = store_with_three_docs.filter_by({})
assert len(results) == 0

def test_add_docs_without_metadata(self, empty_store):
items = [
{"text": "Document one with no metadata"},
{"text": "Document two with no metadata"},
]
ids = empty_store.add(items)
assert len(ids) == 2

results = empty_store.query("Document one")
assert len(results) > 0

def test_add_docs_with_nonstring_metadata(self, empty_store):
items = [
{
"text": "Doc with int metadata",
"metadata": {"count": 42, "description": None},
},
{
"text": "Doc with list metadata",
"metadata": {"tags": ["foo", "bar"]},
},
]
empty_store.add(items)

results = empty_store.filter_by({})
assert len(results) == 2

def test_add_long_text_chunking(self, empty_store):
"""
Verifies that adding a document with text longer than `chunk_size` is chunked properly.
"""
# For demonstration, set a small chunk_size
empty_store.chunk_size = 50 # force chunking fairly quickly

# Create a long text, repeated enough times to exceed the chunk_size
long_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 10
items = [
{
"text": long_text,
"metadata": {"title": "long_document"},
}
]
ids = empty_store.add(items)

# Should have multiple chunks, hence multiple IDs
# (The exact number depends on chunk_size & text length.)
assert len(ids) > 1, "Expected more than one chunk/ID due to forced chunking"

def test_query_long_text_chunking(self, empty_store):
"""
Verifies querying a store containing a large text still returns sensible results.
"""
empty_store.chunk_size = 60
long_text = " ".join(["word"] * 300) # 300 words, ensures multiple chunks

items = [
{
"text": long_text,
"metadata": {"title": "very_large_document"},
}
]
empty_store.add(items)

# Query for a word we know is in the text.
# Not a perfect test since this is a mocked embedding, but ensures no errors.
results = empty_store.query("word", top_k=3)
assert len(results) <= 3, "We requested top_k=3"
assert len(results) > 0, "Should have at least one chunk matching"

# The top result should logically be from our big text.
top_result = results[0]
assert "very_large_document" in str(top_result.get("metadata")), (
"Expected the big doc chunk to show up in the results"
)

def test_add_multiple_large_documents(self, empty_store):
"""
Verifies behavior when multiple large documents are added.
"""
# Force chunking to a smaller chunk_size
empty_store.chunk_size = 100

doc1 = "Doc1 " + ("ABC " * 200)
doc2 = "Doc2 " + ("XYZ " * 200)
items = [
{
"text": doc1,
"metadata": {"title": "large_document_1"},
},
{
"text": doc2,
"metadata": {"title": "large_document_2"},
},
]

ids = empty_store.add(items)
# At least more than 2 chunks, since each doc is forced to chunk
assert len(ids) > 2, "Expected more than 2 chunks total for two large docs"

# Query something from doc2
results = empty_store.query("XYZ", top_k=5)
assert len(results) <= 5
# Expect at least 1 chunk from doc2
# This is a simplistic check, but ensures chunking doesn't break queries
found_doc2 = any("large_document_2" in r["metadata"].get("title", "") for r in results)
assert found_doc2, "Expected to find at least one chunk belonging to doc2"

# Query something from doc1
results = empty_store.query("ABC", top_k=5)
assert len(results) <= 5
found_doc1 = any("large_document_1" in r["metadata"].get("title", "") for r in results)
assert found_doc1, "Expected to find at least one chunk belonging to doc1"


class TestNumpyVectorStore(VectorStoreTestKit):

@pytest.fixture
def store(self):
store = NumpyVectorStore(embeddings=NumpyEmbeddings())
store.clear()
return store


class TestDuckDBVectorStore(VectorStoreTestKit):

@pytest.fixture
def store(self, tmp_path):
db_path = str(tmp_path / "test_duckdb.db")
store = DuckDBVectorStore(uri=db_path, embeddings=NumpyEmbeddings())
store.clear()
return store
