Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Oct 18, 2024
1 parent e1ef209 commit 3b71f19
Show file tree
Hide file tree
Showing 3 changed files with 523 additions and 397 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ def __init__(self, dimension: int) -> None:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self.embed_query(txt) for txt in texts]

async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)

def embed_query(self, text: str) -> list[float]:
try:
vals = json.loads(text)
Expand All @@ -48,9 +45,6 @@ def embed_query(self, text: str) -> list[float]:
assert len(vals) == self.dimension
return vals

async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)


@pytest.fixture
def embedding_d2() -> Embeddings:
Expand All @@ -59,14 +53,13 @@ def embedding_d2() -> Embeddings:

class EarthEmbeddings(Embeddings):
def get_vector_near(self, value: float) -> List[float]:
return [value + (random.random() / 100.0), value - (random.random() / 100.0)]
base_point = [value, (1 - value**2) ** 0.5]
fluctuation = random.random() / 100.0
return [base_point[0] + fluctuation, base_point[1] - fluctuation]

def embed_documents(self, texts: list[str]) -> list[list[float]]:
return [self.embed_query(txt) for txt in texts]

async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)

def embed_query(self, text: str) -> list[float]:
words = set(text.lower().split())
if "earth" in words:
Expand All @@ -77,9 +70,6 @@ def embed_query(self, text: str) -> list[float]:
vector = self.get_vector_near(0.1)
return vector

async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)


def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]:
return [doc.id for doc in docs]
Expand Down Expand Up @@ -149,9 +139,20 @@ def graph_vector_store_docs() -> list[Document]:
return docs_a + docs_b + docs_f + docs_t


class CassandraSession:
table_name: str
session: Any

def __init__(self, table_name: str, session: Any):
self.table_name = table_name
self.session = session


@contextmanager
def cassandra_session(table_name: str, drop: bool = True) -> Generator[Any, None, None]:
# Initialize the Cassandra cluster and session
def get_cassandra_session(
table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
"""Initialize the Cassandra cluster and session"""
from cassandra.cluster import Cluster

if "CASSANDRA_CONTACT_POINTS" in os.environ:
Expand All @@ -167,20 +168,18 @@ def cassandra_session(table_name: str, drop: bool = True) -> Generator[Any, None
session = cluster.connect()

try:
# Ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
" WITH replication = "
"{'class': 'SimpleStrategy', 'replication_factor': 1}"
)
)
# Drop table if required
if drop:
session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")

# Yield the session for usage
yield session
yield CassandraSession(table_name=table_name, session=session)
finally:
# Ensure proper shutdown/cleanup of resources
session.shutdown()
Expand All @@ -191,38 +190,38 @@ def cassandra_session(table_name: str, drop: bool = True) -> Generator[Any, None
def graph_vector_store_angular(
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with cassandra_session(table_name=table_name) as session:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=AngularTwoDimensionalEmbeddings(),
session=session,
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=table_name,
table_name=session.table_name,
)


@pytest.fixture(scope="function")
def graph_vector_store_earth(
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with cassandra_session(table_name=table_name) as session:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=EarthEmbeddings(),
session=session,
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=table_name,
table_name=session.table_name,
)


@pytest.fixture(scope="function")
def graph_vector_store_fake(
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with cassandra_session(table_name=table_name) as session:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=FakeEmbeddings(),
session=session,
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=table_name,
table_name=session.table_name,
)


Expand All @@ -231,12 +230,12 @@ def graph_vector_store_d2(
embedding_d2: Embeddings,
table_name: str = "graph_test_table",
) -> Generator[CassandraGraphVectorStore, None, None]:
with cassandra_session(table_name=table_name) as session:
with get_cassandra_session(table_name=table_name) as session:
yield CassandraGraphVectorStore(
embedding=embedding_d2,
session=session,
session=session.session,
keyspace=TEST_KEYSPACE,
table_name=table_name,
table_name=session.table_name,
)


Expand Down Expand Up @@ -620,20 +619,20 @@ def test_gvs_add_nodes_sync(
Link(kind="kC", direction="in", tag="tC"),
]
nodes = [
Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0),
Node(text="[0, 1]", metadata={"m": 1}, links=links1),
Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
]
graph_vector_store_d2.add_nodes(nodes)
hits = graph_vector_store_d2.similarity_search_by_vector([0, 3])
hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1])
assert len(hits) == 2
assert hits[0].id == "id0"
assert hits[0].page_content == "[0, 2]"
assert hits[0].page_content == "[1, 0]"
md0 = hits[0].metadata
assert md0["m"] == "0.0"
assert any(isinstance(v, set) for k, v in md0.items() if k != "m")

assert hits[1].id != "id0"
assert hits[1].page_content == "[0, 1]"
assert hits[1].page_content == "[-1, 0]"
md1 = hits[1].metadata
assert md1["m"] == "1.0"
assert any(isinstance(v, set) for k, v in md1.items() if k != "m")
Expand All @@ -651,21 +650,21 @@ async def test_gvs_add_nodes_async(
Link(kind="kC", direction="in", tag="tC"),
]
nodes = [
Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0),
Node(text="[0, 1]", metadata={"m": 1}, links=links1),
Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0),
Node(text="[-1, 0]", metadata={"m": 1}, links=links1),
]
async for _ in graph_vector_store_d2.aadd_nodes(nodes):
pass

hits = await graph_vector_store_d2.asimilarity_search_by_vector([0, 3])
hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1])
assert len(hits) == 2
assert hits[0].id == "id0"
assert hits[0].page_content == "[0, 2]"
assert hits[0].page_content == "[1, 0]"
md0 = hits[0].metadata
assert md0["m"] == "0.0"
assert any(isinstance(v, set) for k, v in md0.items() if k != "m")
assert hits[1].id != "id0"
assert hits[1].page_content == "[0, 1]"
assert hits[1].page_content == "[-1, 0]"
md1 = hits[1].metadata
assert md1["m"] == "1.0"
assert any(isinstance(v, set) for k, v in md1.items() if k != "m")
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,25 @@ async def aembed_query(self, text: str) -> list[float]:
return self.embed_query(text)


def _embedding_d2() -> Embeddings:
@pytest.fixture
def embedding_d2() -> Embeddings:
return ParserEmbeddings(dimension=2)


class CassandraSession:
table_name: str
session: Any

def __init__(self, table_name: str, session: Any):
self.table_name = table_name
self.session = session


@contextmanager
def cassandra_session(table_name: str, drop: bool = True) -> Generator[Any, None, None]:
# Initialize the Cassandra cluster and session
def get_cassandra_session(
table_name: str, drop: bool = True
) -> Generator[CassandraSession, None, None]:
"""Initialize the Cassandra cluster and session"""
from cassandra.cluster import Cluster

if "CASSANDRA_CONTACT_POINTS" in os.environ:
Expand All @@ -74,20 +86,18 @@ def cassandra_session(table_name: str, drop: bool = True) -> Generator[Any, None
session = cluster.connect()

try:
# Ensure keyspace exists
session.execute(
(
f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}"
" WITH replication = "
"{'class': 'SimpleStrategy', 'replication_factor': 1}"
)
)
# Drop table if required
if drop:
session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}")

# Yield the session for usage
yield session
yield CassandraSession(table_name=table_name, session=session)
finally:
# Ensure proper shutdown/cleanup of resources
session.shutdown()
Expand All @@ -96,49 +106,40 @@ def cassandra_session(table_name: str, drop: bool = True) -> Generator[Any, None

@contextmanager
def vector_store(
embedding: Embeddings,
table_name: str,
setup_mode: SetupMode,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
drop: bool = True,
) -> Generator[Cassandra, None, None]:
# Open a session with the context manager
with cassandra_session(table_name=table_name, drop=drop) as session:
try:
# Yield the Cassandra instance with the open session
yield Cassandra(
table_name=table_name,
keyspace=TEST_KEYSPACE,
session=session, # Pass the session to Cassandra
embedding=_embedding_d2(),
setup_mode=setup_mode,
metadata_indexing=metadata_indexing,
)
finally:
# Cleanup happens in cassandra_session context manager automatically
pass
with get_cassandra_session(table_name=table_name, drop=drop) as session:
yield Cassandra(
table_name=session.table_name,
keyspace=TEST_KEYSPACE,
session=session.session,
embedding=embedding,
setup_mode=setup_mode,
metadata_indexing=metadata_indexing,
)


@contextmanager
def graph_vector_store(
embedding: Embeddings,
table_name: str,
setup_mode: SetupMode,
metadata_deny_list: Optional[list[str]] = None,
drop: bool = True,
) -> Generator[CassandraGraphVectorStore, None, None]:
# Open a session with the context manager
with cassandra_session(table_name=table_name, drop=drop) as session:
try:
yield CassandraGraphVectorStore(
table_name=table_name,
keyspace=TEST_KEYSPACE,
session=session,
embedding=_embedding_d2(),
setup_mode=setup_mode,
metadata_deny_list=metadata_deny_list,
)
finally:
# Cleanup happens in cassandra_session context manager automatically
pass
with get_cassandra_session(table_name=table_name, drop=drop) as session:
yield CassandraGraphVectorStore(
table_name=session.table_name,
keyspace=TEST_KEYSPACE,
session=session.session,
embedding=embedding,
setup_mode=setup_mode,
metadata_deny_list=metadata_deny_list,
)


def _vs_indexing_policy(table_name: str) -> Union[Tuple[str, Iterable[str]], str]:
Expand Down Expand Up @@ -177,6 +178,7 @@ class TestUpgradeToGraphVectorStore:
def test_upgrade_to_gvs_success_sync(
self,
*,
embedding_d2: Embeddings,
gvs_setup_mode: SetupMode,
table_name: str,
gvs_metadata_deny_list: list[str],
Expand All @@ -186,6 +188,7 @@ def test_upgrade_to_gvs_success_sync(

# Create vector store using SetupMode.SYNC
with vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=SetupMode.SYNC,
metadata_indexing=_vs_indexing_policy(table_name=table_name),
Expand All @@ -202,6 +205,7 @@ def test_upgrade_to_gvs_success_sync(
# Create a GRAPH Vector Store using the existing collection from above
# with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy
with graph_vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=gvs_setup_mode,
metadata_deny_list=gvs_metadata_deny_list,
Expand All @@ -226,6 +230,7 @@ def test_upgrade_to_gvs_success_sync(
async def test_upgrade_to_gvs_success_async(
self,
*,
embedding_d2: Embeddings,
gvs_setup_mode: SetupMode,
table_name: str,
gvs_metadata_deny_list: list[str],
Expand All @@ -235,6 +240,7 @@ async def test_upgrade_to_gvs_success_async(

# Create vector store using SetupMode.ASYNC
with vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=SetupMode.ASYNC,
metadata_indexing=_vs_indexing_policy(table_name=table_name),
Expand All @@ -251,6 +257,7 @@ async def test_upgrade_to_gvs_success_async(
# Create a GRAPH Vector Store using the existing collection from above
# with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy
with graph_vector_store(
embedding=embedding_d2,
table_name=table_name,
setup_mode=gvs_setup_mode,
metadata_deny_list=gvs_metadata_deny_list,
Expand Down
Loading

0 comments on commit 3b71f19

Please sign in to comment.