
Commit

more tests
epinzur committed Oct 17, 2024
1 parent 4259378 commit 86fdb71
Showing 3 changed files with 646 additions and 167 deletions.
@@ -101,171 +101,6 @@ def embed_query(self, text: str) -> List[float]:
def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
return [doc.id for doc in docs]


def test_mmr_traversal() -> None:
"""
Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
the following vectors on a circle (numbered v0 through v3):
          ______ v2
         /      \
        /        |  v1
    v3 |    .    |  query
        |        /  v0
        |______/          (N.B. very crude drawing)
With fetch_k==2 and k==2, when the query is at (1, ),
one expects that v2 and v0 are returned (in some order)
because v1 is "too close" to v0 (and v0 is closer than v1).
Both v2 and v3 are reachable via edges from v0, so once it is
selected, those are both considered.
"""
g_store = _graphvectorstore_from_documents(
docs=[],
embedding=AngularTwoDimensionalEmbeddings(),
)

v0 = Node(
id="v0",
text="-0.124",
links=[
Link.outgoing(kind="explicit", tag="link"),
],
)
v1 = Node(
id="v1",
text="+0.127",
)
v2 = Node(
id="v2",
text="+0.25",
links=[
Link.incoming(kind="explicit", tag="link"),
],
)
v3 = Node(
id="v3",
text="+1.0",
links=[
Link.incoming(kind="explicit", tag="link"),
],
)
g_store.add_nodes([v0, v1, v2, v3])

results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
assert _result_ids(results) == ["v0", "v2"]

# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
# So it ends up picking "v1" even though it's similar to "v0".
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
assert _result_ids(results) == ["v0", "v1"]

# With max depth 0 but higher `fetch_k`, we encounter v2
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
assert _result_ids(results) == ["v0", "v2"]

# v0's score is 0.46 and v2's score is 0.16, so v2 won't be chosen.
results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
assert _result_ids(results) == ["v0"]

# with k=4 we should get all of the documents.
results = g_store.mmr_traversal_search("0.0", k=4)
assert _result_ids(results) == ["v0", "v2", "v1", "v3"]


def test_write_retrieve_keywords() -> None:
greetings = Node(
id="greetings",
text="Typical Greetings",
links=[
Link.incoming(kind="parent", tag="parent"),
],
)

node1 = Node(
id="doc1",
text="Hello World",
links=[
Link.outgoing(kind="parent", tag="parent"),
Link.bidir(kind="kw", tag="greeting"),
Link.bidir(kind="kw", tag="world"),
],
)

node2 = Node(
id="doc2",
text="Hello Earth",
links=[
Link.outgoing(kind="parent", tag="parent"),
Link.bidir(kind="kw", tag="greeting"),
Link.bidir(kind="kw", tag="earth"),
],
)

g_store = _graphvectorstore_from_documents(
docs=[],
embedding=FakeEmbeddings(),
)

g_store.add_nodes(nodes=[greetings, node1, node2])

# Doc2 is more similar, but World and Earth are similar enough that doc1 also
# shows up.
results: Iterable[Document] = g_store.similarity_search("Earth", k=2)
assert _result_ids(results) == ["doc2", "doc1"]

results = g_store.similarity_search("Earth", k=1)
assert _result_ids(results) == ["doc2"]

results = g_store.traversal_search("Earth", k=2, depth=0)
assert _result_ids(results) == ["doc2", "doc1"]

results = g_store.traversal_search("Earth", k=2, depth=1)
assert _result_ids(results) == ["doc2", "doc1", "greetings"]

# K=1 only pulls in doc2 (Hello Earth)
results = g_store.traversal_search("Earth", k=1, depth=0)
assert _result_ids(results) == ["doc2"]

# k=1 only pulls in doc2 (Hello Earth). depth=1 then also traverses its
# parent and keyword edges.
results = g_store.traversal_search("Earth", k=1, depth=1)
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}


def test_metadata() -> None:
g_store = _graphvectorstore_from_documents(
docs=[],
embedding=FakeEmbeddings(),
)

doc_a = Node(
id="a",
text="A",
metadata={"other": "some other field"},
links=[
Link.incoming(kind="hyperlink", tag="http://a"),
Link.bidir(kind="other", tag="foo"),
],
)

g_store.add_nodes([doc_a])
results = g_store.similarity_search("A")
assert len(results) == 1
assert results[0].id == "a"
metadata = results[0].metadata
assert metadata["other"] == "some other field"
assert set(metadata[METADATA_LINKS_KEY]) == {
Link.incoming(kind="hyperlink", tag="http://a"),
Link.bidir(kind="other", tag="foo"),
}


### NEW STUFF


def _graph_vector_store_docs() -> list[Document]:
"""
This is a set of Documents to pre-populate a graph vector store,
@@ -378,7 +213,6 @@ def _graphvectorstore_from_texts(
metadata_indexing=metadata_indexing,
)


async def _graphvectorstore_from_texts_async(
texts: List[str],
embedding: Embeddings,
@@ -420,7 +254,6 @@ def _graphvectorstore_from_documents(
metadata_indexing=metadata_indexing,
)


async def _graphvectorstore_from_documents_async(
docs: List[Document],
embedding: Embeddings,
@@ -459,6 +292,169 @@ def _populated_graph_vector_store_d2() -> CassandraGraphVectorStore:
return g_store


def test_mmr_traversal() -> None:
"""
Test end to end construction and MMR search.
The embedding function used here ensures `texts` become
the following vectors on a circle (numbered v0 through v3):
          ______ v2
         /      \
        /        |  v1
    v3 |    .    |  query
        |        /  v0
        |______/          (N.B. very crude drawing)
With fetch_k==2 and k==2, when the query is at (1, ),
one expects that v2 and v0 are returned (in some order)
because v1 is "too close" to v0 (and v0 is closer than v1).
Both v2 and v3 are reachable via edges from v0, so once it is
selected, those are both considered.
"""
g_store = _graphvectorstore_from_documents(
docs=[],
embedding=AngularTwoDimensionalEmbeddings(),
)

v0 = Node(
id="v0",
text="-0.124",
links=[
Link.outgoing(kind="explicit", tag="link"),
],
)
v1 = Node(
id="v1",
text="+0.127",
)
v2 = Node(
id="v2",
text="+0.25",
links=[
Link.incoming(kind="explicit", tag="link"),
],
)
v3 = Node(
id="v3",
text="+1.0",
links=[
Link.incoming(kind="explicit", tag="link"),
],
)
g_store.add_nodes([v0, v1, v2, v3])

results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2)
assert _result_ids(results) == ["v0", "v2"]

# With max depth 0, no edges are traversed, so this doesn't reach v2 or v3.
# So it ends up picking "v1" even though it's similar to "v0".
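# (With fetch_k=2 the initial candidates are only the two nearest neighbors
# of the query, v0 and v1; depth=0 keeps them from being expanded along the
# edges that lead from v0 to v2 and v3.)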
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0)
assert _result_ids(results) == ["v0", "v1"]

# With max depth 0 but higher `fetch_k`, we encounter v2
results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0)
assert _result_ids(results) == ["v0", "v2"]

# v0's score is 0.46 and v2's score is 0.16, so v2 won't be chosen.
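# (Assuming the default lambda_mult of 0.5, the MMR score is
# 0.5 * sim(query, doc) - 0.5 * max sim(doc, selected):
# v0: 0.5 * cos(0.124 * pi) ~= 0.46
# v2: 0.5 * cos(0.25 * pi) - 0.5 * cos(0.374 * pi) ~= 0.35 - 0.19 = 0.16,
# which is below the 0.2 threshold once v0 has been selected.)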
results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2)
assert _result_ids(results) == ["v0"]

# with k=4 we should get all of the documents.
results = g_store.mmr_traversal_search("0.0", k=4)
assert _result_ids(results) == ["v0", "v2", "v1", "v3"]


def test_write_retrieve_keywords() -> None:
greetings = Node(
id="greetings",
text="Typical Greetings",
links=[
Link.incoming(kind="parent", tag="parent"),
],
)

node1 = Node(
id="doc1",
text="Hello World",
links=[
Link.outgoing(kind="parent", tag="parent"),
Link.bidir(kind="kw", tag="greeting"),
Link.bidir(kind="kw", tag="world"),
],
)

node2 = Node(
id="doc2",
text="Hello Earth",
links=[
Link.outgoing(kind="parent", tag="parent"),
Link.bidir(kind="kw", tag="greeting"),
Link.bidir(kind="kw", tag="earth"),
],
)

g_store = _graphvectorstore_from_documents(
docs=[],
embedding=FakeEmbeddings(),
)

g_store.add_nodes(nodes=[greetings, node1, node2])

# Doc2 is more similar, but World and Earth are similar enough that doc1 also
# shows up.
results: Iterable[Document] = g_store.similarity_search("Earth", k=2)
assert _result_ids(results) == ["doc2", "doc1"]

results = g_store.similarity_search("Earth", k=1)
assert _result_ids(results) == ["doc2"]

results = g_store.traversal_search("Earth", k=2, depth=0)
assert _result_ids(results) == ["doc2", "doc1"]

results = g_store.traversal_search("Earth", k=2, depth=1)
assert _result_ids(results) == ["doc2", "doc1", "greetings"]

# K=1 only pulls in doc2 (Hello Earth)
results = g_store.traversal_search("Earth", k=1, depth=0)
assert _result_ids(results) == ["doc2"]

# k=1 only pulls in doc2 (Hello Earth). depth=1 then also traverses its
# parent and keyword edges.
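# (doc2 links out to its "greetings" parent and shares the bidirectional
# "greeting" keyword link with doc1, so both are reachable in one hop.)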
results = g_store.traversal_search("Earth", k=1, depth=1)
assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"}


def test_metadata() -> None:
g_store = _graphvectorstore_from_documents(
docs=[],
embedding=FakeEmbeddings(),
)

doc_a = Node(
id="a",
text="A",
metadata={"other": "some other field"},
links=[
Link.incoming(kind="hyperlink", tag="http://a"),
Link.bidir(kind="other", tag="foo"),
],
)

g_store.add_nodes([doc_a])
results = g_store.similarity_search("A")
assert len(results) == 1
assert results[0].id == "a"
metadata = results[0].metadata
assert metadata["other"] == "some other field"
assert set(metadata[METADATA_LINKS_KEY]) == {
Link.incoming(kind="hyperlink", tag="http://a"),
Link.bidir(kind="other", tag="foo"),
}




def test_gvs_similarity_search_sync() -> None:
"""Simple (non-graph) similarity search on a graph vector g_store."""
g_store = _populated_graph_vector_store_d2()
