diff --git a/.gitignore b/.gitignore index ccffb05..d8d77c5 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/backend/app/api/routes/chat.py b/backend/app/api/routes/chat.py index 1d67122..f200b95 100644 --- a/backend/app/api/routes/chat.py +++ b/backend/app/api/routes/chat.py @@ -1,19 +1,26 @@ import json import asyncio import logging -from fastapi import APIRouter, Request -from pydantic import BaseModel +from typing import List, Dict, Optional +from fastapi import APIRouter, Request, Depends +from pydantic import BaseModel, Field from sse_starlette.sse import EventSourceResponse from app.services.llm import get_english_translation, get_response_stream_async from app.services.embedding import embedding_service from app.services.database import get_client +from app.core.rate_limit import limiter router = APIRouter() logger = logging.getLogger(__name__) +class HistoryMessage(BaseModel): + role: str + content: str + class ChatRequest(BaseModel): query: str + history: List[HistoryMessage] = Field(default_factory=list) def _search_documents(query_vector): return get_client().rpc( @@ -21,12 +28,12 @@ def _search_documents(query_vector): {'query_embedding': query_vector, 'match_count': 3} ).execute() -async def generate_chat_events(request: Request, query: str): +async def generate_chat_events(request: Request, query: str, history: List[HistoryMessage]): """ Generator function that streams SSE events. It yields 'metadata' first, then chunks of 'content'. """ - # 1. Translate Korean query to English + # 1. Translate Korean query to English // Note: We don't translate history here to save costs and reduce latency try: english_query = await asyncio.to_thread(get_english_translation, query) except Exception: @@ -79,8 +86,28 @@ async def generate_chat_events(request: Request, query: str): # 6. Emit Event 2: content (Text chunk streaming via LLM) combined_context = "\n\n".join(contexts) + MAX_HISTORY_MESSAGES = 20 + MAX_HISTORY_CHARS = 1000 + history_tail = (history or [])[-MAX_HISTORY_MESSAGES:] + formatted_parts: List[str] = [] + + for msg in history_tail: + role = str(msg.role or "").lower() + if role == "user": + role_name = "User" + elif role in ("ai", "agent", "philorag"): + role_name = "Agent (PhiloRAG)" + else: + continue + content = (msg.content or "").strip() + if not content: + continue + formatted_parts.append(f"{role_name}: {content[:MAX_HISTORY_CHARS]}") + + formatted_history = "\n\n".join(formatted_parts) + try: - async for chunk in get_response_stream_async(context=combined_context, query=english_query): + async for chunk in get_response_stream_async(context=combined_context, query=english_query, history=formatted_history): # If client disconnects, stop generating if await request.is_disconnected(): break @@ -94,8 +121,9 @@ async def generate_chat_events(request: Request, query: str): return @router.post("") +@limiter.limit("5/minute") async def chat_endpoint(request: Request, chat_request: ChatRequest): """ Endpoint for accepting chat queries and returning a text/event-stream response. 
""" - return EventSourceResponse(generate_chat_events(request, chat_request.query)) + return EventSourceResponse(generate_chat_events(request, chat_request.query, chat_request.history)) diff --git a/backend/app/core/rate_limit.py b/backend/app/core/rate_limit.py new file mode 100644 index 0000000..da73772 --- /dev/null +++ b/backend/app/core/rate_limit.py @@ -0,0 +1,12 @@ +from slowapi import Limiter +from slowapi.util import get_remote_address +from starlette.requests import Request + +def get_real_client_ip(request: Request) -> str: + """Get real client IP considering X-Forwarded-For header in a proxy environment.""" + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + return get_remote_address(request) + +limiter = Limiter(key_func=get_real_client_ip) diff --git a/backend/app/main.py b/backend/app/main.py index d07c4db..c60a307 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,6 +1,10 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from slowapi import _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded + from app.api.routes import chat +from app.core.rate_limit import limiter app = FastAPI( title="PhiloRAG API", @@ -8,6 +12,10 @@ version="1.0.0", ) +# Register rate limiter +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + # Configure CORS app.add_middleware( CORSMiddleware, diff --git a/backend/app/services/llm.py b/backend/app/services/llm.py index e7fb641..b97af39 100644 --- a/backend/app/services/llm.py +++ b/backend/app/services/llm.py @@ -1,23 +1,35 @@ +import threading import google.generativeai as genai from app.core.config import settings from langchain_core.prompts import PromptTemplate from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.output_parsers import StrOutputParser -if not settings.GEMINI_API_KEY: - raise RuntimeError("GEMINI_API_KEY must be configured") +# Models will be instantiated lazily or during function call +_llm = None +_llm_lock = threading.Lock() -# Configure Gemini API natively (optional, if native SDK features are needed) -genai.configure(api_key=settings.GEMINI_API_KEY) +def get_llm(): + global _llm + if not settings.GEMINI_API_KEY: + raise RuntimeError("GEMINI_API_KEY must be configured") + + if _llm is None: + with _llm_lock: + if _llm is None: # Double-checked locking + # Configure Gemini API natively (optional, if native SDK features are needed) + genai.configure(api_key=settings.GEMINI_API_KEY) + + # Configure LangChain model + # TODO: model gemini-2.5-flash will be deprecated by June 17, 2026. Plan migration to gemini-3-flash. + _llm = ChatGoogleGenerativeAI( + model="gemini-3-flash", + google_api_key=settings.GEMINI_API_KEY, + temperature=0.7, + max_retries=2 + ) + return _llm -# Configure LangChain model -# We use gemini-2.5-flash for faster and highly capable inference -llm = ChatGoogleGenerativeAI( - model="gemini-2.5-flash", - google_api_key=settings.GEMINI_API_KEY, - temperature=0.7, - max_retries=2 -) translation_prompt = PromptTemplate.from_template( """Translate the following user query from Korean to English. @@ -31,22 +43,25 @@ def get_english_translation(korean_query: str) -> str: """ Translates a Korean query to English using Gemini via LangChain. 
""" - chain = translation_prompt | llm | StrOutputParser() + chain = translation_prompt | get_llm() | StrOutputParser() return chain.invoke({"query": korean_query}) def get_rag_prompt() -> PromptTemplate: """ - Returns the core RAG prompt template taking English context and the translated query, + Returns the core RAG prompt template taking English context, history, and the translated query, requesting the output in Korean. """ template = """ You are 'PhiloRAG', a philosophical chatbot providing wisdom and comfort based on Eastern and Western philosophies. - Use the following English philosophical context to answer the user's question. + Use the following English philosophical context and the chat history to answer the user's question. Your final answer must be in Korean. Context: {context} + Recent Chat History: + {chat_history} + User Query (English translation): {query} @@ -54,19 +69,19 @@ def get_rag_prompt() -> PromptTemplate: """ return PromptTemplate.from_template(template) -def get_response_stream(context: str, query: str): +def get_response_stream(context: str, query: str, history: str = ""): """ Returns a stream of strings from the LLM. """ prompt = get_rag_prompt() - chain = prompt | llm | StrOutputParser() - return chain.stream({"context": context, "query": query}) + chain = prompt | get_llm() | StrOutputParser() + return chain.stream({"context": context, "chat_history": history, "query": query}) -async def get_response_stream_async(context: str, query: str): +async def get_response_stream_async(context: str, query: str, history: str = ""): """ Returns an async stream of strings from the LLM. """ prompt = get_rag_prompt() - chain = prompt | llm | StrOutputParser() - async for chunk in chain.astream({"context": context, "query": query}): + chain = prompt | get_llm() | StrOutputParser() + async for chunk in chain.astream({"context": context, "chat_history": history, "query": query}): yield chunk diff --git a/backend/pytest_log_utf8.txt b/backend/pytest_log_utf8.txt deleted file mode 100644 index b7d8a76..0000000 --- a/backend/pytest_log_utf8.txt +++ /dev/null @@ -1,26 +0,0 @@ -============================= test session starts ============================= -platform win32 -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -rootdir: C:\Users\ysn65\Desktop\antigravity\philo-rag\backend -plugins: anyio-4.12.1, asyncio-1.3.0, cov-7.0.0 -asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function -collected 0 items / 1 error - -=================================== ERRORS ==================================== -___________________ ERROR collecting tests/unit/test_llm.py ___________________ -tests\unit\test_llm.py:12: in - from app.services.llm import get_english_translation, get_response_stream, get_response_stream_async -app\services\llm.py:8: in - raise RuntimeError("GEMINI_API_KEY must be configured") -E RuntimeError: GEMINI_API_KEY must be configured -============================== warnings summary =============================== -:488 - :488: DeprecationWarning: Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14. - -:488 - :488: DeprecationWarning: Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14. 
- --- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -=========================== short test summary info =========================== -ERROR tests/unit/test_llm.py - RuntimeError: GEMINI_API_KEY must be configured -!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!! -======================== 2 warnings, 1 error in 4.52s ========================= diff --git a/backend/requirements-dev.txt b/backend/requirements-dev.txt new file mode 100644 index 0000000..5fcdfdb --- /dev/null +++ b/backend/requirements-dev.txt @@ -0,0 +1 @@ +pytest-asyncio>=0.23.0 diff --git a/backend/requirements.txt b/backend/requirements.txt index 04a35dd..b82dfb8 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,4 +11,4 @@ pydantic-settings python-dotenv langchain-community==0.4.1 sentence-transformers>=2.2.0,<3.0.0 -pytest-asyncio>=0.23.0 +slowapi>=0.1.9 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..bbd790c --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,37 @@ +import os +import pytest +import asyncio + +@pytest.fixture(autouse=True) +def mock_env_vars(): + """Set dummy environment variables for tests to prevent import errors.""" + os.environ["GEMINI_API_KEY"] = "dummy_test_key" + os.environ["SUPABASE_URL"] = "http://localhost:8000" + os.environ["SUPABASE_SERVICE_KEY"] = "dummy_service_key" + +@pytest.fixture(scope="session") +def event_loop(): + """Create a single event loop instance shared across the entire test session. + Scoped to session to prevent global client implementations (e.g. Supabase Python client) + from binding `asyncio.locks.Event` instances to closed function-scoped event loops.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + +@pytest.fixture(autouse=True) +def reset_sse_starlette_appstatus(): + """Reset the global AppStatus event in sse-starlette to prevent event loop leakage between tests.""" + import sse_starlette.sse + sse_starlette.sse.AppStatus.should_exit_event = None + yield + +@pytest.fixture(autouse=True) +def reset_rate_limiter_storage(): + """Reset the rate limiter's storage between tests so that asyncio Locks are not shared across TestClient event loops.""" + from app.core.rate_limit import limiter + from limits.storage import MemoryStorage + + # limits 3.x MemoryStorage creates an asyncio.Lock internally. + # Resetting it forces re-bind to current test context. + limiter._storage = MemoryStorage() + yield diff --git a/backend/tests/e2e/test_chat_endpoint.py b/backend/tests/e2e/test_chat_endpoint.py index f3a20c0..eb1771a 100644 --- a/backend/tests/e2e/test_chat_endpoint.py +++ b/backend/tests/e2e/test_chat_endpoint.py @@ -11,23 +11,26 @@ def test_health_check(): assert response.json() == {"status": "healthy"} @patch("app.api.routes.chat.embedding_service.generate_embedding") -@patch("app.api.routes.chat.supabase_client.rpc") +@patch("app.api.routes.chat._search_documents") @patch("app.api.routes.chat.get_english_translation") -@patch("app.api.routes.chat.get_response_stream") -def test_chat_endpoint_success(mock_stream, mock_translate, mock_rpc, mock_embed): +@patch("app.api.routes.chat.get_response_stream_async") +def test_chat_endpoint_success(mock_stream, mock_translate, mock_search, mock_embed): # Setup mocks mock_translate.return_value = "What is life?" 
mock_embed.return_value = [0.1] * 384 - # Mock supabase response - mock_execute = MagicMock() - mock_execute.execute.return_value.data = [ + # Mock _search_documents response + mock_response = MagicMock() + mock_response.data = [ {"content": "Life is suffering", "metadata": {"author": "Schopenhauer"}} ] - mock_rpc.return_value = mock_execute + mock_search.return_value = mock_response # Mock LLM stream generator - mock_stream.return_value = (chunk for chunk in ["인생은", " ", "고통입니다."]) + async def mock_async_generator(*_args, **_kwargs): + for chunk in ["인생은", " ", "고통입니다."]: + yield chunk + mock_stream.return_value = mock_async_generator() # Call the actual endpoint with Fastapi test client # Since it's SSE, we stream the response diff --git a/backend/tests/integration/test_chat.py b/backend/tests/integration/test_chat.py new file mode 100644 index 0000000..b7ac854 --- /dev/null +++ b/backend/tests/integration/test_chat.py @@ -0,0 +1,30 @@ +from fastapi.testclient import TestClient +import pytest +from app.main import app +from unittest.mock import patch + +@patch("app.api.routes.chat.generate_chat_events") +def test_chat_rate_limiting(mock_events): + """Test that the chat endpoint correctly limits requests to 5 per minute.""" + + # Mock event generator to bypass ML initializations and remote calls + async def mock_generator(*_args, **_kwargs): + yield "data: ok\n\n" + def _mock_events_factory(*_args, **_kwargs): + return mock_generator() + mock_events.side_effect = _mock_events_factory + + # We will use the synchronous TestClient to avoid asyncio event loop leaking from SSE Streams + with TestClient(app) as client: + # Define a unique client IP for this test to avoid interfering with other tests + headers = {"X-Forwarded-For": "192.168.1.100"} + + # Send 5 requests which should succeed + for _ in range(5): + # We use stream matching to verify successful request dispatching + with client.stream("POST", "/api/v1/chat", json={"query": "Test"}, headers=headers) as response: + assert response.status_code == 200 + + # The 6th request should fail with 429 Too Many Requests + response = client.post("/api/v1/chat", json={"query": "Test"}, headers=headers) + assert response.status_code == 429 diff --git a/backend/tests/integration/test_supabase_match.py b/backend/tests/integration/test_supabase_match.py index 8148c64..82cb428 100644 --- a/backend/tests/integration/test_supabase_match.py +++ b/backend/tests/integration/test_supabase_match.py @@ -10,29 +10,32 @@ async def test_supabase_match_integration(): # 1. 
We mock the embedding service to return a dummy vector with patch("app.api.routes.chat.embedding_service.generate_embedding") as mock_embed, \ - patch("app.api.routes.chat.supabase_client.rpc") as mock_rpc, \ + patch("app.api.routes.chat._search_documents") as mock_search, \ patch("app.api.routes.chat.get_english_translation") as mock_translate, \ - patch("app.api.routes.chat.get_response_stream") as mock_stream: + patch("app.api.routes.chat.get_response_stream_async") as mock_stream: mock_translate.return_value = "English Question" mock_embed.return_value = [0.1, 0.2, 0.3] - # Mock Supabase RPC response chain: .rpc().execute().data - mock_execute = MagicMock() - mock_execute.execute.return_value.data = [ + # Mock Supabase RPC response chain: _search_documents + mock_response = MagicMock() + mock_response.data = [ {"content": "Philosophy is life", "metadata": {"author": "Socrates"}} ] - mock_rpc.return_value = mock_execute + mock_search.return_value = mock_response # Mock LLM stream - mock_stream.return_value = (chunk for chunk in ["답변", "입니다"]) + async def mock_async_generator(*args, **kwargs): + for chunk in ["답변", "입니다"]: + yield chunk + mock_stream.return_value = mock_async_generator() # We need a mock request for the SSE loop from unittest.mock import AsyncMock mock_request = MagicMock() mock_request.is_disconnected = AsyncMock(return_value=False) - generator = generate_chat_events(mock_request, "안녕") + generator = generate_chat_events(mock_request, "안녕", []) events = [] async for event in generator: @@ -43,10 +46,7 @@ async def test_supabase_match_integration(): mock_embed.assert_called_once_with("English Question") # Important: Verify Supabase RPC was called with the exact vector from Embedding Service - mock_rpc.assert_called_once_with( - 'match_documents', - {'query_embedding': [0.1, 0.2, 0.3], 'match_count': 3} - ) + mock_search.assert_called_once_with([0.1, 0.2, 0.3]) # Verify event stream structure assert len(events) == 3 # metadata + "답변" + "입니다" diff --git a/backend/tests/unit/test_chat_models.py b/backend/tests/unit/test_chat_models.py new file mode 100644 index 0000000..4a981c7 --- /dev/null +++ b/backend/tests/unit/test_chat_models.py @@ -0,0 +1,26 @@ +import pytest +from app.api.routes.chat import ChatRequest + +def test_chat_request_accepts_history(): + """Test that ChatRequest model correctly parses history.""" + payload = { + "query": "What is virtue?", + "history": [ + {"role": "user", "content": "Hello"}, + {"role": "ai", "content": "How can I help you today?"} + ] + } + + request = ChatRequest(**payload) + assert request.query == "What is virtue?" + assert len(request.history) == 2 + assert request.history[0].role == "user" + assert request.history[1].content == "How can I help you today?" + +def test_chat_request_empty_history_default(): + """Test that ChatRequest model defaults to empty history.""" + payload = {"query": "What is virtue?"} + + request = ChatRequest(**payload) + assert request.query == "What is virtue?"
+ assert request.history == [] diff --git a/backend/tests/unit/test_llm.py b/backend/tests/unit/test_llm.py index 062ef09..2bef86c 100644 --- a/backend/tests/unit/test_llm.py +++ b/backend/tests/unit/test_llm.py @@ -31,9 +31,11 @@ def test_translation(setup_test_env): print("Testing translation...") from app.services.llm import get_english_translation with patch("app.services.llm.translation_prompt") as mock_prompt, \ - patch("app.services.llm.llm") as _mock_llm, \ + patch("app.services.llm.get_llm") as mock_get_llm, \ patch("app.services.llm.StrOutputParser") as _mock_parser: + _mock_llm = MagicMock() + mock_get_llm.return_value = _mock_llm mock_chain = MagicMock() mock_chain.invoke.return_value = "Translated Text" mock_chain.__or__.return_value = mock_chain @@ -47,9 +49,11 @@ def test_streaming(setup_test_env): print("Testing streaming...") from app.services.llm import get_response_stream with patch("app.services.llm.get_rag_prompt") as mock_prompt, \ - patch("app.services.llm.llm") as _mock_llm, \ + patch("app.services.llm.get_llm") as mock_get_llm, \ patch("app.services.llm.StrOutputParser") as _mock_parser: + _mock_llm = MagicMock() + mock_get_llm.return_value = _mock_llm mock_chain = MagicMock() mock_chain.stream.return_value = ["안녕하세요", " ", "철학자", "입니다."] mock_chain.__or__.return_value = mock_chain @@ -64,9 +68,11 @@ async def test_streaming_async(setup_test_env): print("Testing streaming async...") from app.services.llm import get_response_stream_async with patch("app.services.llm.get_rag_prompt") as mock_prompt, \ - patch("app.services.llm.llm") as _mock_llm, \ + patch("app.services.llm.get_llm") as mock_get_llm, \ patch("app.services.llm.StrOutputParser") as _mock_parser: + _mock_llm = MagicMock() + mock_get_llm.return_value = _mock_llm mock_chain = MagicMock() async def mock_astream(*_args, **_kwargs): for chunk in ["안녕하세요", " ", "철학자", "입니다."]: @@ -81,7 +87,5 @@ async def mock_astream(*_args, **_kwargs): # For manual execution if __name__ == "__main__": - import asyncio - test_translation() - test_streaming() - asyncio.run(test_streaming_async()) + import pytest + raise SystemExit(pytest.main([__file__])) diff --git a/frontend/app/page.tsx b/frontend/app/page.tsx index 40ffeeb..a12073c 100644 --- a/frontend/app/page.tsx +++ b/frontend/app/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState } from "react"; +import { useState, useCallback } from "react"; import { Sidebar } from "../components/sidebar/Sidebar"; import { ChatMain } from "../components/chat/ChatMain"; import { Message } from "../types/chat"; @@ -9,6 +9,38 @@ export default function Home() { const [messages, setMessages] = useState([]); const [isSubmitting, setIsSubmitting] = useState(false); + const processLine = useCallback((line: string, eventObj: { current: string }, aiMsgId: string): boolean => { + if (line.startsWith("event: ")) { + eventObj.current = line.substring(7).trim(); + } else if (line.startsWith("data: ")) { + const currentData = line.substring(6).replace(/\r$/, ""); + const currentEvent = eventObj.current; + + if (currentEvent === "metadata" && currentData.trim() !== "") { + try { + const metaJson = JSON.parse(currentData); + const philosophersArray = Array.isArray(metaJson.philosophers) ? metaJson.philosophers : []; + setMessages((prev) => + prev.map(msg => msg.id === aiMsgId ? 
{ ...msg, metadata: philosophersArray } : msg) + ); + } catch { console.error("Could not parse metadata event:", currentData) } + } else if (currentEvent === "content") { + // un-escape \\n to real newlines + const char = currentData.replace(/\\n/g, '\n'); + setMessages((prev) => + prev.map(msg => msg.id === aiMsgId ? { ...msg, content: msg.content + char } : msg) + ); + } else if (currentEvent === "error") { + console.error("Chat error:", currentData); + setMessages((prev) => + prev.map(msg => msg.id === aiMsgId ? { ...msg, content: currentData, isStreaming: false } : msg) + ); + return true; + } + } + return false; + }, []); + const handleSendMessage = async (query: string) => { if (!query.trim() || isSubmitting) return; @@ -35,50 +67,32 @@ export default function Home() { setIsSubmitting(true); try { + const historyToSend = messages.slice(-10).map(msg => ({ + role: msg.role, + content: msg.content + })); + const baseUrl = process.env.NEXT_PUBLIC_API_BASE_URL || "http://localhost:8000"; const res = await fetch(`${baseUrl}/api/v1/chat`, { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ query: query }) + body: JSON.stringify({ query: query, history: historyToSend }) }); + if (res.status === 429) { + setMessages((prev) => + prev.map(msg => msg.id === aiMsgId ? { ...msg, isStreaming: false, content: "입력량이 너무 많습니다. 잠시 후 1분 뒤에 다시 철학자와 대화를 시도해 주세요." } : msg) + ); + return; + } + if (!res.ok) throw new Error(`Failed to fetch: ${res.status} ${res.statusText}`); const reader = res.body?.getReader(); const decoder = new TextDecoder(); if (!reader) throw new Error("No reader"); - const processLine = (line: string, eventObj: { current: string }): boolean => { - if (line.startsWith("event: ")) { - eventObj.current = line.substring(7).trim(); - } else if (line.startsWith("data: ")) { - const currentData = line.substring(6); - const currentEvent = eventObj.current; - - if (currentEvent === "metadata" && currentData.trim() !== "") { - try { - const metaJson = JSON.parse(currentData); - const philosophersArray = Array.isArray(metaJson.philosophers) ? metaJson.philosophers : []; - setMessages((prev) => - prev.map(msg => msg.id === aiMsgId ? { ...msg, metadata: philosophersArray } : msg) - ); - } catch { console.error("Could not parse metadata event:", currentData) } - } else if (currentEvent === "content") { - // un-escape \\n to real newlines - const char = currentData.replace(/\\n/g, '\n'); - setMessages((prev) => - prev.map(msg => msg.id === aiMsgId ? { ...msg, content: msg.content + char } : msg) - ); - } else if (currentEvent === "error") { - console.error("Chat error:", currentData); - setMessages((prev) => - prev.map(msg => msg.id === aiMsgId ? 
{ ...msg, content: currentData, isStreaming: false } : msg) - ); - return true; - } - } - return false; - }; + // Process line memoized above const eventObj = { current: "" }; let buffer = ""; @@ -93,7 +107,7 @@ export default function Home() { if (buffer) { const lines = buffer.split('\n'); for (const line of lines) { - if (processLine(line, eventObj)) { + if (processLine(line, eventObj, aiMsgId)) { shouldStop = true; break; } @@ -109,7 +123,7 @@ export default function Home() { buffer = lines.pop() || ""; for (const line of lines) { - if (processLine(line, eventObj)) { + if (processLine(line, eventObj, aiMsgId)) { shouldStop = true; break; } diff --git a/frontend/components/chat/FloatingInput.tsx b/frontend/components/chat/FloatingInput.tsx index 176dc34..e458c62 100644 --- a/frontend/components/chat/FloatingInput.tsx +++ b/frontend/components/chat/FloatingInput.tsx @@ -48,7 +48,7 @@ export function FloatingInput({ onSendMessage, isSubmitting }: FloatingInputProp lastCompositionEndAt.current = Date.now(); }} onKeyDown={(e) => { - if (isComposing.current || Date.now() - lastCompositionEndAt.current < 50) return; + if (e.nativeEvent.isComposing || isComposing.current || Date.now() - lastCompositionEndAt.current < 50) return; if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); diff --git a/frontend/components/chat/MessageList.tsx b/frontend/components/chat/MessageList.tsx index fe173ba..82ee820 100644 --- a/frontend/components/chat/MessageList.tsx +++ b/frontend/components/chat/MessageList.tsx @@ -71,11 +71,28 @@ export function MessageList({ messages, onOpenCitation }: Props) { {/* Citation Cards if metadata exists */} - {msg.metadata && msg.metadata.length > 0 && Array.from(new Set(msg.metadata.map(m => m.book_info.title))).map((title, idx) => { - const meta = msg.metadata?.find(m => m.book_info.title === title); - if (!meta) return null; + {msg.metadata && msg.metadata.length > 0 && Array.from(new Map(msg.metadata.map((m) => [m.id, m])).values()).map((meta) => { + const title = meta.book_info.title; + const isClickable = Boolean(onOpenCitation); + const interactiveProps = isClickable + ? { + role: "button" as const, + tabIndex: 0, + onClick: () => onOpenCitation?.(meta), + onKeyDown: (e: React.KeyboardEvent) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + onOpenCitation?.(meta); + } + }, + } + : {}; return ( -
+
{meta.book_info.cover_url && !meta.book_info.cover_url.includes("dummy") ? ( <> @@ -96,7 +113,7 @@ export function MessageList({ messages, onOpenCitation }: Props) { - -
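Usage sketch (illustrative only, not part of the patch): the request body now carries an optional history list alongside query, the response streams SSE metadata/content/error events, and slowapi returns 429 once a client IP exceeds 5 requests per minute. This sketch assumes the backend from this diff is running at http://localhost:8000 (the frontend's default NEXT_PUBLIC_API_BASE_URL fallback) and uses httpx purely as an example client; it is not a dependency of this codebase.

import httpx  # example HTTP client; any SSE-capable client would do

payload = {
    "query": "삶의 의미는 무엇인가요?",
    "history": [
        {"role": "user", "content": "안녕하세요"},
        {"role": "ai", "content": "무엇이 궁금하신가요?"},
    ],
}

with httpx.stream("POST", "http://localhost:8000/api/v1/chat", json=payload, timeout=None) as response:
    if response.status_code == 429:
        # slowapi rejects requests beyond the "5/minute" limit for this client IP
        print("Rate limited; retry in a minute.")
    else:
        event = ""
        for raw in response.iter_lines():
            line = raw.rstrip("\r")
            if line.startswith("event: "):
                event = line[len("event: "):].strip()
            elif line.startswith("data: ") and event == "content":
                # content chunks escape newlines as \n, matching the frontend's un-escaping
                print(line[len("data: "):].replace("\\n", "\n"), end="", flush=True)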