Commit 847daa7

fixed tests

epinzur committed Oct 18, 2024
1 parent e1ef209 commit 847daa7

Showing 2 changed files with 450 additions and 345 deletions.
@@ -36,9 +36,6 @@ def __init__(self, dimension: int) -> None:
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return [self.embed_query(txt) for txt in texts]
 
-    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
-        return self.embed_documents(texts)
-
     def embed_query(self, text: str) -> list[float]:
         try:
             vals = json.loads(text)
@@ -48,9 +45,6 @@ def embed_query(self, text: str) -> list[float]:
         assert len(vals) == self.dimension
         return vals
 
-    async def aembed_query(self, text: str) -> list[float]:
-        return self.embed_query(text)
-
 
 @pytest.fixture
 def embedding_d2() -> Embeddings:
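
Note on the removed async overrides (here and in the hunks above and below): each one was a pure delegation to its sync counterpart. A likely rationale, though the commit message doesn't say: langchain-core's Embeddings base class already ships default aembed_documents / aembed_query implementations that run the sync methods in an executor, making these overrides redundant. A minimal sketch with a hypothetical SyncOnlyEmbeddings class:

import asyncio

from langchain_core.embeddings import Embeddings


class SyncOnlyEmbeddings(Embeddings):
    """Hypothetical subclass that defines only the sync methods."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        return [self.embed_query(txt) for txt in texts]

    def embed_query(self, text: str) -> list[float]:
        return [1.0, 0.0]  # placeholder 2-D vector


async def main() -> None:
    emb = SyncOnlyEmbeddings()
    # No aembed_query override here, yet the async call still works:
    # the base-class default wraps the sync method in an executor.
    print(await emb.aembed_query("hello"))


asyncio.run(main())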
@@ -59,14 +53,13 @@ def embedding_d2() -> Embeddings:
 
 class EarthEmbeddings(Embeddings):
     def get_vector_near(self, value: float) -> List[float]:
-        return [value + (random.random() / 100.0), value - (random.random() / 100.0)]
+        base_point = [value, (1 - value**2) ** 0.5]
+        fluctuation = random.random() / 100.0
+        return [base_point[0] + fluctuation, base_point[1] - fluctuation]
 
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         return [self.embed_query(txt) for txt in texts]
 
-    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
-        return self.embed_documents(texts)
-
     def embed_query(self, text: str) -> list[float]:
         words = set(text.lower().split())
         if "earth" in words:
@@ -77,9 +70,6 @@ def embed_query(self, text: str) -> list[float]:
             vector = self.get_vector_near(0.1)
         return vector
 
-    async def aembed_query(self, text: str) -> list[float]:
-        return self.embed_query(text)
-
 
 def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
     return [doc.id for doc in docs]
@@ -620,20 +610,20 @@ def test_gvs_add_nodes_sync(
         Link(kind="kC", direction="in", tag="tC"),
     ]
     nodes = [
-        Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0),
-        Node(text="[0, 1]", metadata={"m": 1}, links=links1),
+        Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
+        Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
     ]
     graph_vector_store_d2.add_nodes(nodes)
-    hits = graph_vector_store_d2.similarity_search_by_vector([0, 3])
+    hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
     assert len(hits) == 2
     assert hits[0].id == "id0"
-    assert hits[0].page_content == "[0, 2]"
+    assert hits[0].page_content == "[1, 0]"
     md0 = hits[0].metadata
     assert md0["m"] == "0.0"
     assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
 
     assert hits[1].id != "id0"
-    assert hits[1].page_content == "[0, 1]"
+    assert hits[1].page_content == "[-1, 0]"
     md1 = hits[1].metadata
     assert md1["m"] == "1.0"
     assert any(isinstance(v, set) for k, v in md1.items() if k != "m")
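
The replacement vectors make the expected hit order unambiguous. The d2 fixture's ParserEmbeddings maps each bracketed page_content string to the literal 2-D vector it spells, and the old documents [0, 2] and [0, 1] both point in the same direction as the old query [0, 3]: under cosine ranking they tie exactly, so which document surfaced first was undefined. With query [0.9, 0.1], document [1, 0] is clearly nearest and [-1, 0] clearly farthest; the same reasoning covers the async variant below. A quick check, assuming cosine (or dot-product) ranking, which the commit doesn't state:

import math


def cos_sim(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


query = [0.9, 0.1]
print(cos_sim(query, [1.0, 0.0]))   # ~0.994  -> expected first hit (id0)
print(cos_sim(query, [-1.0, 0.0]))  # ~-0.994 -> expected second hit
# The old vectors tie exactly, leaving the order undefined:
print(cos_sim([0.0, 3.0], [0.0, 2.0]), cos_sim([0.0, 3.0], [0.0, 1.0]))  # 1.0 1.0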
@@ -651,21 +641,21 @@ async def test_gvs_add_nodes_async(
         Link(kind="kC", direction="in", tag="tC"),
     ]
     nodes = [
-        Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0),
-        Node(text="[0, 1]", metadata={"m": 1}, links=links1),
+        Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
+        Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
     ]
     async for _ in graph_vector_store_d2.aadd_nodes(nodes):
         pass
 
-    hits = await graph_vector_store_d2.asimilarity_search_by_vector([0, 3])
+    hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1])
     assert len(hits) == 2
     assert hits[0].id == "id0"
-    assert hits[0].page_content == "[0, 2]"
+    assert hits[0].page_content == "[1, 0]"
     md0 = hits[0].metadata
     assert md0["m"] == "0.0"
     assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
     assert hits[1].id != "id0"
-    assert hits[1].page_content == "[0, 1]"
+    assert hits[1].page_content == "[-1, 0]"
     md1 = hits[1].metadata
     assert md1["m"] == "1.0"
     assert any(isinstance(v, set) for k, v in md1.items() if k != "m")