diff --git a/api/ask_astro/services/questions.py b/api/ask_astro/services/questions.py index d3289d51..04d6de10 100644 --- a/api/ask_astro/services/questions.py +++ b/api/ask_astro/services/questions.py @@ -5,9 +5,6 @@ import time from logging import getLogger -from langchain import callbacks - -from ask_astro.chains.answer_question import answer_question_chain from ask_astro.clients.firestore import firestore_client from ask_astro.config import FirestoreCollections from ask_astro.models.request import AskAstroRequest, Source @@ -33,6 +30,10 @@ async def answer_question(request: AskAstroRequest) -> None: :param request: The request to answer the question. """ try: + from langchain import callbacks + + from ask_astro.chains.answer_question import answer_question_chain + # First, mark the request as in_progress and add it to the database request.status = "in_progress" await _update_firestore_request(request) diff --git a/tests/api/ask_astro/rest/controllers/test_list_recent_requests.py b/tests/api/ask_astro/rest/controllers/test_list_recent_requests.py index 9197479a..e3d2e978 100644 --- a/tests/api/ask_astro/rest/controllers/test_list_recent_requests.py +++ b/tests/api/ask_astro/rest/controllers/test_list_recent_requests.py @@ -33,7 +33,6 @@ def generate_mock_document(data): async def test_on_list_recent_requests(app, mock_data, expected_status, expected_response): with patch("ask_astro.config.FirestoreCollections.requests", new_callable=PropertyMock) as mock_collection: mock_collection.return_value = "mock_collection_name" - print("Setting up mock") with patch("google.cloud.firestore_v1.Client", new=AsyncMock()) as MockFirestoreClient: # Here, MockFirestoreClient will replace the actual Firestore Client everywhere in the code. mock_client_instance = MockFirestoreClient.return_value diff --git a/tests/api/ask_astro/rest/controllers/test_post_request.py b/tests/api/ask_astro/rest/controllers/test_post_request.py new file mode 100644 index 00000000..8d14eb43 --- /dev/null +++ b/tests/api/ask_astro/rest/controllers/test_post_request.py @@ -0,0 +1,50 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from sanic import Sanic +from sanic_testing import TestManager + +sanitized_name = __name__.replace(".", "_") + + +@pytest.fixture +def app(): + """Fixture to create a new Sanic application for testing the POST request handler.""" + app_instance = Sanic(sanitized_name) + TestManager(app_instance) + from ask_astro.rest.controllers import register_routes + + register_routes(app_instance) + return app_instance + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "request_payload,expected_status,expected_response", + [ + ({"prompt": "Tell me about space"}, 200, {"request_uuid": "test-uuid"}), + ({"prompt": "What is quantum mechanics?"}, 200, {"request_uuid": "test-uuid"}), + ({}, 400, {"error": "prompt is required"}), + ], +) +async def test_on_post_request(app, request_payload, expected_status, expected_response): + """Test the POST request endpoint behavior based on different input payloads.""" + with patch("ask_astro.services.questions.answer_question") as mock_answer_question, patch( + "ask_astro.clients.firestore.firestore.AsyncClient" + ) as mock_firestore: + with patch("google.cloud.firestore_v1.Client", new=AsyncMock()): + mock_firestore.collection.return_value.document.return_value.get.return_value = AsyncMock() + mock_answer_question.return_value = AsyncMock() + + request, response = await app.asgi_client.post("/requests", json=request_payload) + + assert response.status == expected_status + # If expecting a 200 status + if expected_status == 200: + assert "request_uuid" in response.json + assert isinstance(response.json.get("request_uuid"), str) + + # If expecting a 400 status + elif expected_status == 400: + assert "error" in response.json + assert response.json["error"] == expected_response["error"]