Merged · 15 commits · Changes from all commits
Binary file modified .gitignore
Binary file not shown.
40 changes: 34 additions & 6 deletions backend/app/api/routes/chat.py
@@ -1,32 +1,39 @@
import json
import asyncio
import logging
from fastapi import APIRouter, Request
from pydantic import BaseModel
from typing import List, Dict, Optional
from fastapi import APIRouter, Request, Depends
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from app.services.llm import get_english_translation, get_response_stream_async
from app.services.embedding import embedding_service
from app.services.database import get_client
from app.core.rate_limit import limiter

router = APIRouter()
logger = logging.getLogger(__name__)

class HistoryMessage(BaseModel):
role: str
content: str

class ChatRequest(BaseModel):
query: str
history: List[HistoryMessage] = Field(default_factory=list)

def _search_documents(query_vector):
return get_client().rpc(
'match_documents',
{'query_embedding': query_vector, 'match_count': 3}
).execute()

async def generate_chat_events(request: Request, query: str):
async def generate_chat_events(request: Request, query: str, history: List[HistoryMessage]):
"""
Generator function that streams SSE events.
It yields 'metadata' first, then chunks of 'content'.
"""
# 1. Translate Korean query to English
# 1. Translate Korean query to English (history is not translated, to save cost and reduce latency)
try:
english_query = await asyncio.to_thread(get_english_translation, query)
except Exception:
@@ -79,8 +86,28 @@ async def generate_chat_events(request: Request, query: str):
# 6. Emit Event 2: content (Text chunk streaming via LLM)
combined_context = "\n\n".join(contexts)

MAX_HISTORY_MESSAGES = 20
MAX_HISTORY_CHARS = 1000
history_tail = (history or [])[-MAX_HISTORY_MESSAGES:]
formatted_parts: List[str] = []

for msg in history_tail:
role = str(msg.role or "").lower()
if role == "user":
role_name = "User"
elif role in ("ai", "agent", "philorag"):
role_name = "Agent (PhiloRAG)"
else:
continue
content = (msg.content or "").strip()
if not content:
continue
formatted_parts.append(f"{role_name}: {content[:MAX_HISTORY_CHARS]}")

formatted_history = "\n\n".join(formatted_parts)

try:
async for chunk in get_response_stream_async(context=combined_context, query=english_query):
async for chunk in get_response_stream_async(context=combined_context, query=english_query, history=formatted_history):
# If client disconnects, stop generating
if await request.is_disconnected():
break
@@ -94,8 +121,9 @@ async def generate_chat_events(request: Request, query: str):
return

@router.post("")
@limiter.limit("5/minute")
async def chat_endpoint(request: Request, chat_request: ChatRequest):
"""
Endpoint for accepting chat queries and returning a text/event-stream response.
"""
return EventSourceResponse(generate_chat_events(request, chat_request.query))
return EventSourceResponse(generate_chat_events(request, chat_request.query, chat_request.history))
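
A minimal client-side sketch (not part of the diff) of how the updated endpoint could be exercised with a history payload. It assumes the app is served locally at http://localhost:8000 and that the chat router is mounted at /api/v1/chat, as the integration test below suggests; httpx is used purely for illustration.

import httpx

# Body mirrors ChatRequest: a Korean query plus an optional list of {role, content} messages.
payload = {
    "query": "삶의 의미는 무엇인가요?",
    "history": [
        {"role": "user", "content": "요즘 너무 허무해요"},
        {"role": "ai", "content": "쇼펜하우어라면 의지의 부정을 이야기할 것입니다."},
    ],
}

# The endpoint streams text/event-stream: a 'metadata' event first, then 'content' chunks.
with httpx.stream("POST", "http://localhost:8000/api/v1/chat", json=payload, timeout=60) as resp:
    for line in resp.iter_lines():
        if line.startswith("data:"):
            print(line[len("data:"):].strip())
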
12 changes: 12 additions & 0 deletions backend/app/core/rate_limit.py
@@ -0,0 +1,12 @@
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.requests import Request

def get_real_client_ip(request: Request) -> str:
"""Get real client IP considering X-Forwarded-For header in a proxy environment."""
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
return get_remote_address(request)

limiter = Limiter(key_func=get_real_client_ip)
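
A quick sketch (not part of the diff) of how get_real_client_ip keys the limiter when a reverse proxy injects X-Forwarded-For; the ASGI scope below is hand-built for illustration only.

from starlette.requests import Request
from app.core.rate_limit import get_real_client_ip

# Hand-built scope simulating a request forwarded by a proxy at 10.0.0.2.
scope = {
    "type": "http",
    "headers": [(b"x-forwarded-for", b"203.0.113.7, 10.0.0.2")],
    "client": ("10.0.0.2", 54321),  # the proxy's address, not the end user's
}

# The first X-Forwarded-For entry wins, so rate limiting keys on the real client IP.
print(get_real_client_ip(Request(scope)))  # -> 203.0.113.7
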
8 changes: 8 additions & 0 deletions backend/app/main.py
@@ -1,13 +1,21 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from app.api.routes import chat
from app.core.rate_limit import limiter

app = FastAPI(
title="PhiloRAG API",
description="Backend API for PhiloRAG chatbot system",
version="1.0.0",
)

# Register rate limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Configure CORS
app.add_middleware(
CORSMiddleware,
57 changes: 36 additions & 21 deletions backend/app/services/llm.py
@@ -1,23 +1,35 @@
import threading
import google.generativeai as genai
from app.core.config import settings
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser

if not settings.GEMINI_API_KEY:
raise RuntimeError("GEMINI_API_KEY must be configured")
# Models will be instantiated lazily or during function call
_llm = None
_llm_lock = threading.Lock()

# Configure Gemini API natively (optional, if native SDK features are needed)
genai.configure(api_key=settings.GEMINI_API_KEY)
def get_llm():
global _llm
if not settings.GEMINI_API_KEY:
raise RuntimeError("GEMINI_API_KEY must be configured")

if _llm is None:
with _llm_lock:
if _llm is None: # Double-checked locking
# Configure Gemini API natively (optional, if native SDK features are needed)
genai.configure(api_key=settings.GEMINI_API_KEY)

# Configure LangChain model
# TODO: model gemini-2.5-flash will be deprecated by June 17, 2026. Plan migration to gemini-3-flash.
_llm = ChatGoogleGenerativeAI(
model="gemini-3-flash",
google_api_key=settings.GEMINI_API_KEY,
temperature=0.7,
max_retries=2
)
return _llm

# Configure LangChain model
# We use gemini-2.5-flash for faster and highly capable inference
llm = ChatGoogleGenerativeAI(
model="gemini-2.5-flash",
google_api_key=settings.GEMINI_API_KEY,
temperature=0.7,
max_retries=2
)

translation_prompt = PromptTemplate.from_template(
"""Translate the following user query from Korean to English.
@@ -31,42 +43,45 @@ def get_english_translation(korean_query: str) -> str:
"""
Translates a Korean query to English using Gemini via LangChain.
"""
chain = translation_prompt | llm | StrOutputParser()
chain = translation_prompt | get_llm() | StrOutputParser()
return chain.invoke({"query": korean_query})

def get_rag_prompt() -> PromptTemplate:
"""
Returns the core RAG prompt template taking English context and the translated query,
Returns the core RAG prompt template taking English context, history, and the translated query,
requesting the output in Korean.
"""
template = """
You are 'PhiloRAG', a philosophical chatbot providing wisdom and comfort based on Eastern and Western philosophies.
Use the following English philosophical context to answer the user's question.
Use the following English philosophical context and the chat history to answer the user's question.
Your final answer must be in Korean.

Context:
{context}

Recent Chat History:
{chat_history}

User Query (English translation):
{query}

Philosophical Prescription (in Korean):
"""
return PromptTemplate.from_template(template)

def get_response_stream(context: str, query: str):
def get_response_stream(context: str, query: str, history: str = ""):
"""
Returns a stream of strings from the LLM.
"""
prompt = get_rag_prompt()
chain = prompt | llm | StrOutputParser()
return chain.stream({"context": context, "query": query})
chain = prompt | get_llm() | StrOutputParser()
return chain.stream({"context": context, "chat_history": history, "query": query})

async def get_response_stream_async(context: str, query: str):
async def get_response_stream_async(context: str, query: str, history: str = ""):
"""
Returns an async stream of strings from the LLM.
"""
prompt = get_rag_prompt()
chain = prompt | llm | StrOutputParser()
async for chunk in chain.astream({"context": context, "query": query}):
chain = prompt | get_llm() | StrOutputParser()
async for chunk in chain.astream({"context": context, "chat_history": history, "query": query}):
yield chunk
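
For reference, a small sketch (not part of the diff) of driving the synchronous variant with a pre-formatted history string in the same "Role: content" shape that chat.py builds; the context and messages are placeholders, and a configured GEMINI_API_KEY is assumed.

from app.services.llm import get_response_stream

# History arrives as one pre-formatted string, mirroring the format assembled in chat.py.
history = "User: 요즘 너무 허무해요\n\nAgent (PhiloRAG): 쇼펜하우어는 의지의 부정을 이야기합니다."

# Placeholder English context standing in for the retrieved passages.
context = "Schopenhauer argues that suffering arises from the endless striving of the will."

for chunk in get_response_stream(context=context, query="How can I cope with this emptiness?", history=history):
    print(chunk, end="", flush=True)
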
26 changes: 0 additions & 26 deletions backend/pytest_log_utf8.txt

This file was deleted.

1 change: 1 addition & 0 deletions backend/requirements-dev.txt
@@ -0,0 +1 @@
pytest-asyncio>=0.23.0
2 changes: 1 addition & 1 deletion backend/requirements.txt
@@ -11,4 +11,4 @@ pydantic-settings
python-dotenv
langchain-community==0.4.1
sentence-transformers>=2.2.0,<3.0.0
pytest-asyncio>=0.23.0
slowapi>=0.1.9
37 changes: 37 additions & 0 deletions backend/tests/conftest.py
@@ -0,0 +1,37 @@
import os
import pytest
import asyncio

@pytest.fixture(autouse=True)
def mock_env_vars():
"""Set dummy environment variables for tests to prevent import errors."""
os.environ["GEMINI_API_KEY"] = "dummy_test_key"
os.environ["SUPABASE_URL"] = "http://localhost:8000"
os.environ["SUPABASE_SERVICE_KEY"] = "dummy_service_key"

@pytest.fixture(scope="session")
def event_loop():
"""Create a single event loop instance shared across the entire test session.
Scoped to session to prevent global client implementations (e.g. Supabase Python client)
from binding `asyncio.locks.Event` instances to closed function-scoped event loops."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()

@pytest.fixture(autouse=True)
def reset_sse_starlette_appstatus():
"""Reset the global AppStatus event in sse-starlette to prevent event loop leakage between tests."""
import sse_starlette.sse
sse_starlette.sse.AppStatus.should_exit_event = None
yield

@pytest.fixture(autouse=True)
def reset_rate_limiter_storage():
"""Reset the rate limiter's storage between tests so that asyncio Locks are not shared across TestClient event loops."""
from app.core.rate_limit import limiter
from limits.storage import MemoryStorage

# limits 3.x MemoryStorage creates an asyncio.Lock internally.
# Resetting it forces re-bind to current test context.
limiter._storage = MemoryStorage()
yield
19 changes: 11 additions & 8 deletions backend/tests/e2e/test_chat_endpoint.py
@@ -11,23 +11,26 @@ def test_health_check():
assert response.json() == {"status": "healthy"}

@patch("app.api.routes.chat.embedding_service.generate_embedding")
@patch("app.api.routes.chat.supabase_client.rpc")
@patch("app.api.routes.chat._search_documents")
@patch("app.api.routes.chat.get_english_translation")
@patch("app.api.routes.chat.get_response_stream")
def test_chat_endpoint_success(mock_stream, mock_translate, mock_rpc, mock_embed):
@patch("app.api.routes.chat.get_response_stream_async")
def test_chat_endpoint_success(mock_stream, mock_translate, mock_search, mock_embed):
# Setup mocks
mock_translate.return_value = "What is life?"
mock_embed.return_value = [0.1] * 384

# Mock supabase response
mock_execute = MagicMock()
mock_execute.execute.return_value.data = [
# Mock _search_documents response
mock_response = MagicMock()
mock_response.data = [
{"content": "Life is suffering", "metadata": {"author": "Schopenhauer"}}
]
mock_rpc.return_value = mock_execute
mock_search.return_value = mock_response

# Mock LLM stream generator
mock_stream.return_value = (chunk for chunk in ["인생은", " ", "고통입니다."])
async def mock_async_generator(*_args, **_kwargs):
for chunk in ["인생은", " ", "고통입니다."]:
yield chunk
mock_stream.return_value = mock_async_generator()

# Call the actual endpoint with Fastapi test client
# Since it's SSE, we stream the response
30 changes: 30 additions & 0 deletions backend/tests/integration/test_chat.py
@@ -0,0 +1,30 @@
from fastapi.testclient import TestClient
import pytest
from app.main import app
from unittest.mock import patch

@patch("app.api.routes.chat.generate_chat_events")
def test_chat_rate_limiting(mock_events):
"""Test that the chat endpoint correctly limits requests to 5 per minute."""

# Mock event generator to bypass ML initializations and remote calls
async def mock_generator(*_args, **_kwargs):
yield "data: ok\n\n"
def _mock_events_factory(*_args, **_kwargs):
return mock_generator()
mock_events.side_effect = _mock_events_factory

# Use the synchronous TestClient to avoid asyncio event loops leaking from SSE streams
with TestClient(app) as client:
# Define a unique client IP for this test to avoid interfering with other tests
headers = {"X-Forwarded-For": "192.168.1.100"}

# Send 5 requests which should succeed
for _ in range(5):
# Stream the response to verify the request is dispatched successfully
with client.stream("POST", "/api/v1/chat", json={"query": "Test"}, headers=headers) as response:
assert response.status_code == 200

# The 6th request should fail with 429 Too Many Requests
response = client.post("/api/v1/chat", json={"query": "Test"}, headers=headers)
assert response.status_code == 429