Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions backend/tests/apps/ai/agent/nodes_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import openai
import pytest

from apps.ai.agent.nodes import AgentNodes
from apps.ai.common.constants import DEFAULT_CHUNKS_RETRIEVAL_LIMIT, DEFAULT_SIMILARITY_THRESHOLD


class TestAgentNodes:
@pytest.fixture
def mock_openai(self, mocker):
mocker.patch("os.getenv", return_value="fake-key")
return mocker.patch("apps.ai.agent.nodes.openai.OpenAI")

@pytest.fixture
def nodes(self, mock_openai, mocker):
mocker.patch("apps.ai.agent.nodes.Retriever")
mocker.patch("apps.ai.agent.nodes.Generator")
return AgentNodes()

def test_init_raises_error_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
with pytest.raises(
ValueError, match="DJANGO_OPEN_AI_SECRET_KEY environment variable not set"
):
AgentNodes()

def test_retrieve_logic(self, nodes, mocker):
state = {"query": "test query"}

mock_metadata = {"entity_types": ["code"], "filters": {}, "requested_fields": []}
nodes.extract_query_metadata = mocker.Mock(return_value=mock_metadata)

nodes.retriever.retrieve.return_value = [{"text": "chunk1", "similarity": 0.9}]
nodes.filter_chunks_by_metadata = mocker.Mock(
return_value=[{"text": "chunk1", "similarity": 0.9}]
)

new_state = nodes.retrieve(state)

assert "context_chunks" in new_state
assert len(new_state["context_chunks"]) == 1
assert new_state["extracted_metadata"] == mock_metadata

nodes.retriever.retrieve.assert_called_with(
query="test query",
limit=DEFAULT_CHUNKS_RETRIEVAL_LIMIT,
similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD,
content_types=["code"],
)

def test_retrieve_skips_if_chunks_present(self, nodes):
state = {"context_chunks": ["existing"]}
new_state = nodes.retrieve(state)
assert new_state == state

def test_generate_logic(self, nodes):
state = {"query": "test query", "context_chunks": []}
nodes.generator.generate_answer.return_value = "Generated answer"

new_state = nodes.generate(state)

assert new_state["answer"] == "Generated answer"
assert new_state["iteration"] == 1
assert len(new_state["history"]) == 1
assert new_state["history"][0]["answer"] == "Generated answer"

def test_evaluate_requires_more_context(self, nodes, mocker):
state = {"query": "test", "answer": "unsure", "extracted_metadata": {}}

mock_eval = {"requires_more_context": True, "feedback": "need more info"}
nodes.call_evaluator = mocker.Mock(return_value=mock_eval)

nodes.retriever.retrieve.return_value = ["new_chunk"]
nodes.filter_chunks_by_metadata = mocker.Mock(return_value=["new_chunk"])

new_state = nodes.evaluate(state)

assert new_state["feedback"] == "Expand and refine answer using newly retrieved context."
assert "context_chunks" in new_state
assert new_state["evaluation"] == mock_eval

def test_evaluate_complete(self, nodes, mocker):
state = {"query": "test", "answer": "good"}
mock_eval = {"requires_more_context": False, "feedback": None, "complete": True}
nodes.call_evaluator = mocker.Mock(return_value=mock_eval)

new_state = nodes.evaluate(state)
assert new_state["feedback"] is None
assert new_state["evaluation"] == mock_eval

def test_route_from_evaluation(self, nodes):
assert nodes.route_from_evaluation({"evaluation": {"complete": True}}) == "complete"
assert (
nodes.route_from_evaluation({"evaluation": {"complete": False}, "iteration": 0})
== "refine"
)
assert (
nodes.route_from_evaluation({"evaluation": {"complete": False}, "iteration": 100})
== "complete"
)

def test_filter_chunks_by_metadata(self, nodes):
chunks = [
{"text": "foo", "additional_context": {"lang": "python"}, "similarity": 0.8},
{"text": "bar", "additional_context": {"lang": "go"}, "similarity": 0.9},
]
metadata = {"filters": {"lang": "python"}, "requested_fields": []}

filtered = nodes.filter_chunks_by_metadata(chunks, metadata, limit=10)
assert filtered[0]["text"] == "foo"

def test_extract_query_metadata_openai_error(self, nodes, mocker):
mocker.patch(
"apps.ai.agent.nodes.Prompt.get_metadata_extractor_prompt", return_value="sys prompt"
)
nodes.openai_client.chat.completions.create.side_effect = openai.OpenAIError("Error")

metadata = nodes.extract_query_metadata("query")
assert metadata["intent"] == "general query"

def test_call_evaluator_openai_error(self, nodes, mocker):
nodes.generator.prepare_context.return_value = "ctx"
mocker.patch(
"apps.ai.agent.nodes.Prompt.get_evaluator_system_prompt", return_value="sys prompt"
)
nodes.openai_client.chat.completions.create.side_effect = openai.OpenAIError("Error")

eval_result = nodes.call_evaluator(query="q", answer="a", context_chunks=[])
assert eval_result["feedback"] == "Evaluator error or invalid response."
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import json
from http import HTTPStatus

import pytest
from django.http import HttpResponse
from django.test import RequestFactory

from apps.common.middlewares.block_null_characters import BlockNullCharactersMiddleware


class TestBlockNullCharactersMiddleware:
@pytest.fixture
def middleware(self):
def get_response(_request):
return HttpResponse("OK")

return BlockNullCharactersMiddleware(get_response)

@pytest.fixture
def factory(self):
return RequestFactory()

def test_clean_request_passes(self, middleware, factory):
request = factory.get("/clean/path")
response = middleware(request)
assert response.status_code == HTTPStatus.OK
assert response.content == b"OK"

def test_null_in_path_blocks(self, middleware, factory):
request = factory.get("/path/with/\x00/null")
response = middleware(request)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert json.loads(response.content) == {
"message": "Request contains null characters in URL or parameters "
"which are not allowed.",
"errors": {},
}

def test_null_in_query_params_blocks(self, middleware, factory):
request = factory.get("/clean/path", {"q": "bad\x00value"})
response = middleware(request)
assert response.status_code == HTTPStatus.BAD_REQUEST

def test_null_in_post_data_blocks(self, middleware, factory):
request = factory.post("/clean/path", {"data": "bad\x00value"})
response = middleware(request)
assert response.status_code == HTTPStatus.BAD_REQUEST

def test_null_in_body_blocks(self, middleware, factory):
request = factory.post(
"/clean/path",
data=b'{"key": "bad\x00value"}',
content_type="application/json",
)
response = middleware(request)
assert response.status_code == HTTPStatus.BAD_REQUEST
assert json.loads(response.content) == {
"message": "Request contains null characters in body which are not allowed.",
"errors": {},
}

def test_unicode_null_in_body_blocks(self, middleware, factory):
request = factory.post(
"/clean/path",
data=b'{"key": "bad\\u0000value"}',
content_type="application/json",
)
response = middleware(request)
assert response.status_code == HTTPStatus.BAD_REQUEST
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from apps.github.management.commands.github_update_pull_requests import Command


class TestGithubUpdatePullRequests:
def test_handle_links_issues(self, mocker):
mock_repo = mocker.Mock(name="Repository", id=1)
mock_repo.name = "test-repo"

mock_issue = mocker.Mock(name="Issue", id=10, number=123)
mock_issue.repository = mock_repo

mock_pr = mocker.Mock(name="PullRequest", id=100, number=456)
mock_pr.repository = mock_repo
mock_pr.body = "This closes #123"
mock_pr.related_issues = mocker.Mock()
mock_pr.related_issues.values_list.return_value = []

mock_pr_qs = mocker.Mock()
mock_pr_qs.select_related.return_value.all.return_value = [mock_pr]

mocker.patch(
"apps.github.management.commands.github_update_pull_requests.PullRequest.objects",
mock_pr_qs,
)

mock_issue_qs = mocker.Mock()
mock_issue_qs.filter.return_value = [mock_issue]
mocker.patch(
"apps.github.management.commands.github_update_pull_requests.Issue.objects",
mock_issue_qs,
)

command = Command()
command.stdout = mocker.Mock()
command.handle()

mock_issue_qs.filter.assert_called_with(repository=mock_repo, number__in={123})
mock_pr.related_issues.add.assert_called_with(10)

def test_handle_no_repo_skipped(self, mocker):
mock_pr = mocker.Mock(name="PullRequest", id=100, number=456)
mock_pr.repository = None
mock_pr.related_issues = mocker.Mock()

mock_pr_qs = mocker.Mock()
mock_pr_qs.select_related.return_value.all.return_value = [mock_pr]

mocker.patch(
"apps.github.management.commands.github_update_pull_requests.PullRequest.objects",
mock_pr_qs,
)

command = Command()
command.stdout = mocker.Mock()
command.handle()

mock_pr.related_issues.add.assert_not_called()

def test_handle_no_keywords(self, mocker):
mock_repo = mocker.Mock(name="Repository")
mock_pr = mocker.Mock(name="PullRequest", id=100, number=456)
mock_pr.repository = mock_repo
mock_pr.body = "Just a normal PR"

mock_pr_qs = mocker.Mock()
mock_pr_qs.select_related.return_value.all.return_value = [mock_pr]
mocker.patch(
"apps.github.management.commands.github_update_pull_requests.PullRequest.objects",
mock_pr_qs,
)

mock_issue_objects = mocker.patch(
"apps.github.management.commands.github_update_pull_requests.Issue.objects"
)

command = Command()
command.stdout = mocker.Mock()
command.handle()

mock_issue_objects.filter.assert_not_called()
77 changes: 77 additions & 0 deletions backend/tests/apps/github/models/comment_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from django.contrib.contenttypes.models import ContentType

from apps.github.models.comment import Comment
from apps.github.models.user import User


class TestComment:
def test_from_github_populates_fields(self, mocker):
comment = Comment()
gh_comment = mocker.Mock()
gh_comment.body = "Test body"
gh_comment.created_at = "2023-01-01T00:00:00Z"
gh_comment.updated_at = "2023-01-02T00:00:00Z"

mock_author = mocker.Mock(spec=User)
mock_author._state = mocker.Mock()

comment.from_github(gh_comment, author=mock_author)

assert comment.body == "Test body"
assert comment.created_at == "2023-01-01T00:00:00Z"
assert comment.updated_at == "2023-01-02T00:00:00Z"
assert comment.author == mock_author

def test_update_data_creates_new(self, mocker):
mocker.patch(
"apps.github.models.comment.Comment.objects.get", side_effect=Comment.DoesNotExist
)
mock_ct = ContentType(app_label="fake", model="fake")
mock_ct.id = 1
mocker.patch(
"django.contrib.contenttypes.models.ContentType.objects.get_for_model",
return_value=mock_ct,
)

gh_comment = mocker.Mock()
gh_comment.id = 12345
gh_comment.body = "New comment"

mock_save = mocker.patch.object(Comment, "save")

mock_content_object = mocker.Mock()
mock_content_object.pk = 999

comment = Comment.update_data(
gh_comment, author=None, content_object=mock_content_object, save=True
)

assert comment.github_id == 12345
assert comment.object_id == 999
assert comment.content_type == mock_ct
mock_save.assert_called_once()

def test_update_data_updates_existing(self, mocker):
existing_comment = Comment(github_id=12345, body="Old body")
mocker.patch(
"apps.github.models.comment.Comment.objects.get", return_value=existing_comment
)
mock_save = mocker.patch.object(Comment, "save")

gh_comment = mocker.Mock()
gh_comment.id = 12345
gh_comment.body = "Updated body"

comment = Comment.update_data(gh_comment, save=True)

assert comment.body == "Updated body"
assert comment.github_id == 12345
mock_save.assert_called_once()

def test_str_representation(self):
comment = Comment(body="A very long comment body that should be truncated", author=None)
long_body = "A" * 60
comment.body = long_body

assert str(comment).startswith("None - AAAAA")
assert len(str(comment)) <= 60
Loading