Skip to content

Commit

Permalink
Refactor: Add pytest fixtures for memory_chatbot_graph tests and impr…
Browse files Browse the repository at this point in the history
…ove test structure
  • Loading branch information
ogabrielluiz committed Aug 5, 2024
1 parent 4239ab6 commit 46fec80
Showing 1 changed file with 88 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from collections import deque

import pytest

from langflow.components.helpers.Memory import MemoryComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.prompts.Prompt import PromptComponent
from langflow.graph import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.graph.schema import GraphDump


@pytest.fixture
def client():
pass

def test_memory_chatbot():

@pytest.fixture
def memory_chatbot_graph():
session_id = "test_session_id"
template = """{context}
Expand All @@ -32,10 +41,87 @@ def test_memory_chatbot():
chat_output.set(input_value=openai_component.text_response)

graph = Graph(chat_input, chat_output)
return graph


def test_memory_chatbot(memory_chatbot_graph):
# Now we run step by step
expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])
for step in expected_order:
result = graph.step()
result = memory_chatbot_graph.step()
if isinstance(result, Finish):
break
assert step == result.vertex.id


def test_memory_chatbot_dump_structure(memory_chatbot_graph: Graph):
# Now we run step by step
graph_dict = memory_chatbot_graph.dump(
name="Memory Chatbot", description="A memory chatbot", endpoint_name="membot"
)
assert isinstance(graph_dict, dict)
# Test structure
assert "data" in graph_dict
assert "is_component" in graph_dict

data_dict = graph_dict["data"]
assert "nodes" in data_dict
assert "edges" in data_dict
assert "description" in graph_dict
assert "endpoint_name" in graph_dict

# Test data
nodes = data_dict["nodes"]
edges = data_dict["edges"]
description = graph_dict["description"]
endpoint_name = graph_dict["endpoint_name"]

assert len(nodes) == 5
assert len(edges) == 4
assert description is not None
assert endpoint_name is not None


def test_memory_chatbot_dump_components_and_edges(memory_chatbot_graph: Graph):
# Check all components and edges were dumped correctly
graph_dict: GraphDump = memory_chatbot_graph.dump(
name="Memory Chatbot", description="A memory chatbot", endpoint_name="membot"
)

data_dict = graph_dict["data"]
nodes = data_dict["nodes"]
edges = data_dict["edges"]

# sort the nodes by id
nodes = sorted(nodes, key=lambda x: x["id"])

# Check each node
assert nodes[0]["data"]["type"] == "ChatInput"
assert nodes[0]["id"] == "chat_input"

assert nodes[1]["data"]["type"] == "MemoryComponent"
assert nodes[1]["id"] == "chat_memory"

assert nodes[2]["data"]["type"] == "ChatOutput"
assert nodes[2]["id"] == "chat_output"

assert nodes[3]["data"]["type"] == "OpenAIModelComponent"
assert nodes[3]["id"] == "openai"

assert nodes[4]["data"]["type"] == "PromptComponent"
assert nodes[4]["id"] == "prompt"

# Check edges
expected_edges = [
("chat_input", "prompt"),
("chat_memory", "prompt"),
("prompt", "openai"),
("openai", "chat_output"),
]

assert len(edges) == len(expected_edges)

for edge in edges:
source = edge["source"]
target = edge["target"]
assert (source, target) in expected_edges, edge

0 comments on commit 46fec80

Please sign in to comment.