From f4c8197d0d02ecb5b63d68ba069898a37742b676 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 25 Aug 2025 10:07:33 +0200 Subject: [PATCH 01/75] feat: added gema 27b --- .../compose/docker-compose.gemma-27b-gpu.yml | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 docker/compose/docker-compose.gemma-27b-gpu.yml diff --git a/docker/compose/docker-compose.gemma-27b-gpu.yml b/docker/compose/docker-compose.gemma-27b-gpu.yml new file mode 100644 index 00000000..95a76721 --- /dev/null +++ b/docker/compose/docker-compose.gemma-27b-gpu.yml @@ -0,0 +1,46 @@ +services: + gemma_27b_gpu: + image: nillion/nilai-vllm:latest + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + ipc: host + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: > + --model google/gemma-3-27b-it + --gpu-memory-utilization 0.95 + --max-model-len 60000 + --max-num-batched-tokens 60000 + --tensor-parallel-size 1 + --enable-auto-tool-choice + --tool-call-parser llama3_json + --uvicorn-log-level warning + environment: + - SVC_HOST=gemma_27b_gpu + - SVC_PORT=8000 + - ETCD_HOST=etcd + - ETCD_PORT=2379 + - TOOL_SUPPORT=true + - MULTIMODAL_SUPPORT=true + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s +volumes: + hugging_face_models: \ No newline at end of file From 617a590b5953126a92a21033a55ffac02f6c148b Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 25 Aug 2025 11:34:54 +0200 Subject: [PATCH 02/75] feat: implement multimodal content support with image URL validation --- .../src/nilai_api/handlers/image_support.py | 82 ++++++++++++ nilai-api/src/nilai_api/routers/private.py | 12 ++ nilai-models/src/nilai_models/daemon.py | 1 + .../src/nilai_common/api_model.py | 23 +++- .../nilai-common/src/nilai_common/config.py | 2 + tests/unit/nilai_api/routers/test_private.py | 123 ++++++++++++++++++ 6 files changed, 238 insertions(+), 5 deletions(-) create mode 100644 nilai-api/src/nilai_api/handlers/image_support.py diff --git a/nilai-api/src/nilai_api/handlers/image_support.py b/nilai-api/src/nilai_api/handlers/image_support.py new file mode 100644 index 00000000..93f41923 --- /dev/null +++ b/nilai-api/src/nilai_api/handlers/image_support.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from typing import List, Optional, Any +from fastapi import HTTPException +from nilai_common import Message + +@dataclass(frozen=True) +class MultimodalCheck: + has_multimodal: bool + error: Optional[str] = None + + +def _extract_url(image_url_field: Any) -> Optional[str]: + """ + Support both object-with-attr and dict-like shapes. + Returns the URL string or None. + """ + if image_url_field is None: + return None + + url = getattr(image_url_field, "url", None) + if url is not None: + return url + if isinstance(image_url_field, dict): + return image_url_field.get("url") + return None + + +def multimodal_check(messages: List[Message]) -> MultimodalCheck: + """ + Single-pass check: + - detect if any part is type=='image_url' + - validate that image_url.url exists and is a base64 data URL + Returns: + MultimodalCheck(has_multimodal: bool, error: Optional[str]) + """ + has_mm = False + + for m in messages: + content = getattr(m, "content", None) or [] + for item in content: + if getattr(item, "type", None) == "image_url": + has_mm = True + iu = getattr(item, "image_url", None) + url = _extract_url(iu) + if not url: + return MultimodalCheck(True, "image_url.url is required for image_url parts") + if not (url.startswith("data:image/") and ";base64," in url): + return MultimodalCheck(True, "Only base64 data URLs are allowed for images (data:image/...;base64,...)") + + return MultimodalCheck(has_mm, None) + + +def has_multimodal_content(messages: List[Message], precomputed: Optional[MultimodalCheck] = None) -> bool: + """ + Check if any message contains multimodal content (image_url parts). + + Args: + messages: List of messages to check + precomputed: Optional precomputed result from multimodal_check() to avoid re-iterating + + Returns: + True if any message contains image_url parts, False otherwise + """ + res = precomputed or multimodal_check(messages) + return res.has_multimodal + + +def validate_multimodal_content(messages: List[Message], precomputed: Optional[MultimodalCheck] = None) -> None: + """ + Validate that multimodal content (image_url parts) follows the required format. + + Args: + messages: List of messages to validate + precomputed: Optional precomputed result from multimodal_check() to avoid re-iterating + + Raises: + HTTPException(400): When image_url parts don't have required URL or use invalid format + (only base64 data URLs are allowed: data:image/...;base64,...) + """ + res = precomputed or multimodal_check(messages) + if res.error: + raise HTTPException(status_code=400, detail=res.error) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index dfaa34c7..6c3bb6a6 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -6,6 +6,7 @@ from nilai_api.attestation import get_attestation_report from nilai_api.handlers.nilrag import handle_nilrag from nilai_api.handlers.web_search import handle_web_search +from nilai_api.handlers.image_support import multimodal_check from fastapi import APIRouter, Body, Depends, HTTPException, status, Request from fastapi.responses import StreamingResponse @@ -211,6 +212,17 @@ async def chat_completion( status_code=400, detail="Model does not support tool usage, remove tools from request", ) + + multimodal_result = multimodal_check(req.messages) + if multimodal_result.has_multimodal: + if not endpoint.metadata.multimodal_support: + raise HTTPException( + status_code=400, + detail="Model does not support multimodal content, remove image inputs from request", + ) + if multimodal_result.error: + raise HTTPException(status_code=400, detail=multimodal_result.error) + model_url = endpoint.url + "/v1/" logger.info( diff --git a/nilai-models/src/nilai_models/daemon.py b/nilai-models/src/nilai_models/daemon.py index 7402f3f3..bab3ad42 100644 --- a/nilai-models/src/nilai_models/daemon.py +++ b/nilai-models/src/nilai_models/daemon.py @@ -38,6 +38,7 @@ async def get_metadata(num_retries=30): source=f"https://huggingface.co/{model_name}", # Model source supported_features=["chat_completion"], # Capabilities tool_support=SETTINGS.tool_support, # Tool support + multimodal_support=SETTINGS.multimodal_support, # Multimodal support ) except Exception as e: diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index a3ae1e81..94880219 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -1,15 +1,27 @@ import uuid -from typing import Annotated, Iterable, List, Literal, Optional +from typing import Annotated, Iterable, List, Literal, Optional, Union -from openai.types.chat import ChatCompletion, ChatCompletionMessage -from openai.types.chat.chat_completion import Choice as OpenaAIChoice +from openai.types.chat import ChatCompletion from openai.types.chat import ChatCompletionToolParam +from openai.types.chat.chat_completion import Choice as OpenaAIChoice from pydantic import BaseModel, Field -class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] # type: ignore +class ImageUrl(BaseModel): + url: str + detail: Optional[str] = "auto" + + +class MessageContentItem(BaseModel): + type: Literal["text", "image_url"] + text: Optional[str] = None + image_url: Optional[ImageUrl] = None + + +class Message(BaseModel): + role: Literal["system", "user", "assistant", "tool"] + content: Union[str, List[MessageContentItem]] class Choice(OpenaAIChoice): @@ -71,6 +83,7 @@ class ModelMetadata(BaseModel): source: str supported_features: List[str] tool_support: bool + multimodal_support: bool = False class ModelEndpoint(BaseModel): diff --git a/packages/nilai-common/src/nilai_common/config.py b/packages/nilai-common/src/nilai_common/config.py index aad86b01..cc0e6bdc 100644 --- a/packages/nilai-common/src/nilai_common/config.py +++ b/packages/nilai-common/src/nilai_common/config.py @@ -8,6 +8,7 @@ class HostSettings(BaseModel): etcd_host: str = "localhost" etcd_port: int = 2379 tool_support: bool = False + multimodal_support: bool = False gunicorn_workers: int = 10 attestation_host: str = "localhost" attestation_port: int = 8081 @@ -19,6 +20,7 @@ class HostSettings(BaseModel): etcd_host=str(os.getenv("ETCD_HOST", "localhost")), etcd_port=int(os.getenv("ETCD_PORT", 2379)), tool_support=bool(os.getenv("TOOL_SUPPORT", False)), + multimodal_support=bool(os.getenv("MULTIMODAL_SUPPORT", False)), gunicorn_workers=int(os.getenv("NILAI_GUNICORN_WORKERS", 10)), attestation_host=str(os.getenv("ATTESTATION_HOST", "localhost")), attestation_port=int(os.getenv("ATTESTATION_PORT", 8081)), diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 183d5cdb..6c98f777 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -206,3 +206,126 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien "completion_tokens_details": None, "prompt_tokens_details": None, } + + +def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker, client): + mocker.patch("openai.api_key", new="test-api-key") + from openai.types.chat import ChatCompletion + from nilai_common import ModelMetadata, ModelEndpoint + + data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + + mock_chat_completions = MagicMock() + mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat + mocker.patch( + "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance + ) + + multimodal_metadata = ModelMetadata( + id="meta-llama/Llama-3.2-1B-Instruct", + name="meta-llama/Llama-3.2-1B-Instruct", + version="1.0", + description="Multimodal model", + author="Meta", + license="Apache 2.0", + source="https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct", + supported_features=["chat_completion"], + tool_support=False, + multimodal_support=True, + ) + multimodal_endpoint = ModelEndpoint( + url="http://test-model-url", + metadata=multimodal_metadata + ) + + mocker.patch.object( + state, + "get_model", + return_value=multimodal_endpoint + ) + + response = client.post( + "/v1/chat/completions", + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + } + } + ] + } + ], + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 200 + assert "usage" in response.json() + + +def test_chat_completion_with_image_unsupported_model(mock_user, mock_user_manager, client): + response = client.post( + "/v1/chat/completions", + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + } + } + ] + } + ], + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 400 + assert "multimodal content" in response.json()["detail"] + + +def test_chat_completion_with_invalid_image_url(mock_user, mock_user_manager, client): + response = client.post( + "/v1/chat/completions", + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg" + } + } + ] + } + ], + }, + headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 400 + assert "base64 data URLs" in response.json()["detail"] From 941f784fb01e59272c78b9b20251e6c16da33b5f Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 25 Aug 2025 19:37:48 +0200 Subject: [PATCH 03/75] refactor: added multimodal parameter + web search with image in query + image format validator --- .../src/nilai_api/handlers/image_support.py | 26 +++++--- .../src/nilai_api/handlers/web_search.py | 34 ++++++---- nilai-models/src/nilai_models/daemon.py | 11 +++- tests/unit/nilai_api/routers/test_private.py | 63 +++++++++---------- 4 files changed, 78 insertions(+), 56 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/image_support.py b/nilai-api/src/nilai_api/handlers/image_support.py index 93f41923..f353912c 100644 --- a/nilai-api/src/nilai_api/handlers/image_support.py +++ b/nilai-api/src/nilai_api/handlers/image_support.py @@ -3,6 +3,7 @@ from fastapi import HTTPException from nilai_common import Message + @dataclass(frozen=True) class MultimodalCheck: has_multimodal: bool @@ -43,21 +44,28 @@ def multimodal_check(messages: List[Message]) -> MultimodalCheck: iu = getattr(item, "image_url", None) url = _extract_url(iu) if not url: - return MultimodalCheck(True, "image_url.url is required for image_url parts") + return MultimodalCheck( + True, "image_url.url is required for image_url parts" + ) if not (url.startswith("data:image/") and ";base64," in url): - return MultimodalCheck(True, "Only base64 data URLs are allowed for images (data:image/...;base64,...)") + return MultimodalCheck( + True, + "Only base64 data URLs are allowed for images (data:image/...;base64,...)", + ) return MultimodalCheck(has_mm, None) -def has_multimodal_content(messages: List[Message], precomputed: Optional[MultimodalCheck] = None) -> bool: +def has_multimodal_content( + messages: List[Message], precomputed: Optional[MultimodalCheck] = None +) -> bool: """ Check if any message contains multimodal content (image_url parts). - + Args: messages: List of messages to check precomputed: Optional precomputed result from multimodal_check() to avoid re-iterating - + Returns: True if any message contains image_url parts, False otherwise """ @@ -65,14 +73,16 @@ def has_multimodal_content(messages: List[Message], precomputed: Optional[Multim return res.has_multimodal -def validate_multimodal_content(messages: List[Message], precomputed: Optional[MultimodalCheck] = None) -> None: +def validate_multimodal_content( + messages: List[Message], precomputed: Optional[MultimodalCheck] = None +) -> None: """ Validate that multimodal content (image_url parts) follows the required format. - + Args: messages: List of messages to validate precomputed: Optional precomputed result from multimodal_check() to avoid re-iterating - + Raises: HTTPException(400): When image_url parts don't have required URL or use invalid format (only base64 data URLs are allowed: data:image/...;base64,...) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index bd92caeb..01313ec9 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -11,6 +11,7 @@ Source, WebSearchEnhancedMessages, WebSearchContext, + MessageContentItem, ) from nilai_common import Message @@ -152,22 +153,29 @@ async def perform_web_search_async(query: str) -> WebSearchContext: async def enhance_messages_with_web_search( messages: List[Message], query: str ) -> WebSearchEnhancedMessages: - """Enhance a list of messages with web search context. - - Args: - messages: List of conversation messages to enhance - query: Search query to retrieve web search results for - - Returns: - WebSearchEnhancedMessages containing the original messages with web search - context prepended as a system message, along with source information - """ ctx = await perform_web_search_async(query) - enhanced = [Message(role="system", content=ctx.prompt)] + messages query_source = Source(source="search_query", content=query) + + if not messages or messages[-1].role != "user": + return WebSearchEnhancedMessages( + messages=messages, sources=[query_source] + ctx.sources + ) + + web_search_context = f"\n\nWeb search results:\n{ctx.prompt}" + + last = messages[-1] + items = ( + [MessageContentItem(type="text", text=last.content)] + if isinstance(last.content, str) + else list(last.content) + ) + items.append(MessageContentItem(type="text", text=web_search_context)) + + enhanced_messages = list(messages) + enhanced_messages[-1] = Message(role="user", content=items) + return WebSearchEnhancedMessages( - messages=enhanced, - sources=[query_source] + ctx.sources, + messages=enhanced_messages, sources=[query_source] + ctx.sources ) diff --git a/nilai-models/src/nilai_models/daemon.py b/nilai-models/src/nilai_models/daemon.py index bab3ad42..b37a5497 100644 --- a/nilai-models/src/nilai_models/daemon.py +++ b/nilai-models/src/nilai_models/daemon.py @@ -28,7 +28,12 @@ async def get_metadata(num_retries=30): response.raise_for_status() response_data = response.json() model_name = response_data["data"][0]["id"] - return ModelMetadata( + + supported_features = ["chat_completion"] + if SETTINGS.multimodal_support: + supported_features.append("multimodal") + + metadata = ModelMetadata( id=model_name, # Unique identifier name=model_name, # Human-readable name version="1.0", # Model version @@ -36,11 +41,13 @@ async def get_metadata(num_retries=30): author="", # Model creators license="Apache 2.0", # Usage license source=f"https://huggingface.co/{model_name}", # Model source - supported_features=["chat_completion"], # Capabilities + supported_features=supported_features, # Capabilities tool_support=SETTINGS.tool_support, # Tool support multimodal_support=SETTINGS.multimodal_support, # Multimodal support ) + return metadata + except Exception as e: if not url: logger.warning(f"Failed to build url: {e}") diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 6c98f777..6a28757f 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -208,7 +208,9 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien } -def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker, client): +def test_chat_completion_with_image_support( + mock_user, mock_user_manager, mocker, client +): mocker.patch("openai.api_key", new="test-api-key") from openai.types.chat import ChatCompletion from nilai_common import ModelMetadata, ModelEndpoint @@ -217,7 +219,7 @@ def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker data.pop("signature") data.pop("sources", None) response_data = ChatCompletion(**data) - + mock_chat_completions = MagicMock() mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) mock_chat = MagicMock() @@ -227,13 +229,13 @@ def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker mocker.patch( "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance ) - + multimodal_metadata = ModelMetadata( - id="meta-llama/Llama-3.2-1B-Instruct", - name="meta-llama/Llama-3.2-1B-Instruct", + id="google/gemma-3-4b-it", + name="google/gemma-3-4b-it", version="1.0", description="Multimodal model", - author="Meta", + author="Google", license="Apache 2.0", source="https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct", supported_features=["chat_completion"], @@ -241,16 +243,11 @@ def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker multimodal_support=True, ) multimodal_endpoint = ModelEndpoint( - url="http://test-model-url", - metadata=multimodal_metadata - ) - - mocker.patch.object( - state, - "get_model", - return_value=multimodal_endpoint + url="http://test-model-url", metadata=multimodal_metadata ) - + + mocker.patch.object(state, "get_model", return_value=multimodal_endpoint) + response = client.post( "/v1/chat/completions", json={ @@ -258,17 +255,17 @@ def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { - "role": "user", + "role": "user", "content": [ {"type": "text", "text": "What's in this image?"}, { "type": "image_url", "image_url": { "url": "" - } - } - ] - } + }, + }, + ], + }, ], }, headers={"Authorization": "Bearer test-api-key"}, @@ -277,7 +274,9 @@ def test_chat_completion_with_image_support(mock_user, mock_user_manager, mocker assert "usage" in response.json() -def test_chat_completion_with_image_unsupported_model(mock_user, mock_user_manager, client): +def test_chat_completion_with_image_unsupported_model( + mock_user, mock_user_manager, client +): response = client.post( "/v1/chat/completions", json={ @@ -285,17 +284,17 @@ def test_chat_completion_with_image_unsupported_model(mock_user, mock_user_manag "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { - "role": "user", + "role": "user", "content": [ {"type": "text", "text": "What's in this image?"}, { "type": "image_url", "image_url": { "url": "" - } - } - ] - } + }, + }, + ], + }, ], }, headers={"Authorization": "Bearer test-api-key"}, @@ -312,17 +311,15 @@ def test_chat_completion_with_invalid_image_url(mock_user, mock_user_manager, cl "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { - "role": "user", + "role": "user", "content": [ {"type": "text", "text": "What's in this image?"}, { "type": "image_url", - "image_url": { - "url": "https://example.com/image.jpg" - } - } - ] - } + "image_url": {"url": "https://example.com/image.jpg"}, + }, + ], + }, ], }, headers={"Authorization": "Bearer test-api-key"}, From a74acbc161585739bb7b3842153ad64a87c4e525 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 09:57:30 +0200 Subject: [PATCH 04/75] refactor: update chat completion message structure and change model to google/gemma-3-4b-it --- tests/unit/__init__.py | 4 ++-- tests/unit/nilai_api/routers/test_private.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 3fdd70a0..246d56ce 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,4 +1,4 @@ -from openai.types.chat.chat_completion import ChoiceLogprobs +from openai.types.chat.chat_completion import ChoiceLogprobs, ChatCompletionMessage from nilai_common import ( SignedChatCompletion, @@ -33,7 +33,7 @@ choices=[ Choice( index=0, - message=Message(role="assistant", content="test-content"), + message=ChatCompletionMessage(role="assistant", content="test-content"), finish_reason="stop", logprobs=ChoiceLogprobs(), ) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 6a28757f..b19ee41c 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -251,7 +251,7 @@ def test_chat_completion_with_image_support( response = client.post( "/v1/chat/completions", json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", + "model": "google/gemma-3-4b-it", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { @@ -280,7 +280,7 @@ def test_chat_completion_with_image_unsupported_model( response = client.post( "/v1/chat/completions", json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", + "model": "google/gemma-3-4b-it", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { @@ -307,7 +307,7 @@ def test_chat_completion_with_invalid_image_url(mock_user, mock_user_manager, cl response = client.post( "/v1/chat/completions", json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", + "model": "google/gemma-3-4b-it", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { From 7fe8325eb10dbfbc90fa0de8b233f9ef4ac1a212 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 10:21:53 +0200 Subject: [PATCH 05/75] feat: add Docker Compose configuration for gemma-4b in ci pipeline for tests + added multimodal e2e tests --- .../docker-compose.gemma-4b-gpu.ci.yml | 51 +++++++++++++++++++ tests/e2e/config.py | 1 + tests/e2e/test_openai.py | 48 +++++++++++++++++ 3 files changed, 100 insertions(+) create mode 100644 docker/compose/docker-compose.gemma-4b-gpu.ci.yml diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml new file mode 100644 index 00000000..81203986 --- /dev/null +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -0,0 +1,51 @@ +services: + gemma_4b_gpu: + image: nillion/nilai-vllm:latest + container_name: nilai-gemma_4b_gpu + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + ipc: host + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: > + --model google/gemma-3-4b-it + --gpu-memory-utilization 0.7 + --max-model-len 8192 + --max-num-batched-tokens 8192 + --tensor-parallel-size 1 + --enable-auto-tool-choice + --tool-call-parser llama3_json + --uvicorn-log-level warning + --dtype half + environment: + - SVC_HOST=gemma_4b_gpu + - SVC_PORT=8000 + - ETCD_HOST=etcd + - ETCD_PORT=2379 + - TOOL_SUPPORT=true + - MULTIMODAL_SUPPORT=true + - VLLM_USE_V1=1 + - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 + - CUDA_LAUNCH_BLOCKING=1 + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s +volumes: + hugging_face_models: diff --git a/tests/e2e/config.py b/tests/e2e/config.py index c49eff97..8e4bad7e 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -34,6 +34,7 @@ def api_key_getter(): ], "ci": [ "meta-llama/Llama-3.2-1B-Instruct", + "google/gemma-3-4b-it", ], } diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 5c72d690..d93a08c4 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -952,3 +952,51 @@ def make_request(): assert isinstance(error, openai.RateLimitError), ( "Rate limited responses should be RateLimitError" ) + + +def test_multimodal_with_web_search_e2e(client): + """Test that multimodal models with web search enabled return sources.""" + import base64 + + # Base64 encoded image (minimal for testing) + base64_image = "" + + try: + response = client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What breed of dog is this and what are some interesting facts about this breed? Search the web for additional information." + }, + { + "type": "image_url", + "image_url": { + "url": base64_image + } + } + ] + } + ], + max_tokens=100, + temperature=0.0, + extra_body={"web_search": True}, + ) + + assert response.choices[0].message.content, "Response should contain content" + + # Check that sources are returned when web search is enabled + sources = getattr(response, "sources", None) + assert sources is not None, "Web search responses should have sources" + assert isinstance(sources, list), "Sources should be a list" + assert len(sources) > 0, "Sources should not be empty" + + except Exception as e: + raise Exception(f"Multimodal with web search test failed: {e}") From ae4422a4cc4a939909758985c71cc2f2fc777200 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 10:23:37 +0200 Subject: [PATCH 06/75] fix: ruff format --- tests/e2e/test_openai.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index d93a08c4..a0dd2058 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -957,46 +957,38 @@ def make_request(): def test_multimodal_with_web_search_e2e(client): """Test that multimodal models with web search enabled return sources.""" import base64 - + # Base64 encoded image (minimal for testing) base64_image = "" - + try: response = client.chat.completions.create( model="google/gemma-3-4b-it", messages=[ - { - "role": "system", - "content": "You are a helpful assistant." - }, + {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ { "type": "text", - "text": "What breed of dog is this and what are some interesting facts about this breed? Search the web for additional information." + "text": "What breed of dog is this and what are some interesting facts about this breed? Search the web for additional information.", }, - { - "type": "image_url", - "image_url": { - "url": base64_image - } - } - ] - } + {"type": "image_url", "image_url": {"url": base64_image}}, + ], + }, ], max_tokens=100, temperature=0.0, extra_body={"web_search": True}, ) - + assert response.choices[0].message.content, "Response should contain content" - + # Check that sources are returned when web search is enabled sources = getattr(response, "sources", None) assert sources is not None, "Web search responses should have sources" assert isinstance(sources, list), "Sources should be a list" assert len(sources) > 0, "Sources should not be empty" - + except Exception as e: raise Exception(f"Multimodal with web search test failed: {e}") From 7c8b61ec02a9d15ab76cb330d5af7d8a48ea3a12 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 10:25:57 +0200 Subject: [PATCH 07/75] refactor: remove unused import in e2e and unit tests --- tests/e2e/test_openai.py | 1 - tests/unit/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index a0dd2058..c35a8e05 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -956,7 +956,6 @@ def make_request(): def test_multimodal_with_web_search_e2e(client): """Test that multimodal models with web search enabled return sources.""" - import base64 # Base64 encoded image (minimal for testing) base64_image = "" diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 246d56ce..7e22d6dd 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -2,7 +2,6 @@ from nilai_common import ( SignedChatCompletion, - Message, ModelEndpoint, ModelMetadata, Usage, From 92e8ece6cf62f135d8b0f23028ca8c2783a26e68 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 10:41:46 +0200 Subject: [PATCH 08/75] test: add rate limit checks to multimodal chat completion tests --- tests/unit/nilai_api/routers/test_private.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index b19ee41c..094fafcd 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -229,6 +229,7 @@ def test_chat_completion_with_image_support( mocker.patch( "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance ) + mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) multimodal_metadata = ModelMetadata( id="google/gemma-3-4b-it", @@ -275,8 +276,9 @@ def test_chat_completion_with_image_support( def test_chat_completion_with_image_unsupported_model( - mock_user, mock_user_manager, client + mock_user, mock_user_manager, mocker, client ): + mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) response = client.post( "/v1/chat/completions", json={ @@ -303,7 +305,8 @@ def test_chat_completion_with_image_unsupported_model( assert "multimodal content" in response.json()["detail"] -def test_chat_completion_with_invalid_image_url(mock_user, mock_user_manager, client): +def test_chat_completion_with_invalid_image_url(mock_user, mock_user_manager, mocker, client): + mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) response = client.post( "/v1/chat/completions", json={ From 27a91b972e25d4860dcbe67d4aac8502e6a5c547 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 10:43:08 +0200 Subject: [PATCH 09/75] fix: ruff format --- tests/unit/nilai_api/routers/test_private.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 094fafcd..a5cc489e 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -305,7 +305,9 @@ def test_chat_completion_with_image_unsupported_model( assert "multimodal content" in response.json()["detail"] -def test_chat_completion_with_invalid_image_url(mock_user, mock_user_manager, mocker, client): +def test_chat_completion_with_invalid_image_url( + mock_user, mock_user_manager, mocker, client +): mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) response = client.post( "/v1/chat/completions", From 2fcce351a62310c05a54686f27ddb75ed1d603d9 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 11:50:33 +0200 Subject: [PATCH 10/75] refactor: update message model structure --- .../src/nilai_api/handlers/web_search.py | 18 ++++++++++++---- .../src/nilai_common/api_model.py | 21 +++++++++++-------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 01313ec9..31de093c 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -11,9 +11,9 @@ Source, WebSearchEnhancedMessages, WebSearchContext, - MessageContentItem, + Message, + TextPart, ) -from nilai_common import Message logger = logging.getLogger(__name__) @@ -153,6 +153,16 @@ async def perform_web_search_async(query: str) -> WebSearchContext: async def enhance_messages_with_web_search( messages: List[Message], query: str ) -> WebSearchEnhancedMessages: + """Enhance a list of messages with web search context.Collapse commentComment on line L155jcabrero commented on Aug 26, 2025 jcabreroon Aug 26, 2025MemberDeleted docstring?Write a replyResolve commentCode has comments. Press enter to view. + + Args: + messages: List of conversation messages to enhance + query: Search query to retrieve web search results for + + Returns: + WebSearchEnhancedMessages containing the original messages with web search + context prepended as a system message, along with source information + """ ctx = await perform_web_search_async(query) query_source = Source(source="search_query", content=query) @@ -165,11 +175,11 @@ async def enhance_messages_with_web_search( last = messages[-1] items = ( - [MessageContentItem(type="text", text=last.content)] + [TextPart(type="text", text=last.content)] if isinstance(last.content, str) else list(last.content) ) - items.append(MessageContentItem(type="text", text=web_search_context)) + items.append(TextPart(type="text", text=web_search_context)) enhanced_messages = list(messages) enhanced_messages[-1] = Message(role="user", content=items) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 94880219..c449d11f 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -2,26 +2,29 @@ from typing import Annotated, Iterable, List, Literal, Optional, Union -from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat import ChatCompletionToolParam from openai.types.chat.chat_completion import Choice as OpenaAIChoice from pydantic import BaseModel, Field -class ImageUrl(BaseModel): +class ImageURL(BaseModel): url: str - detail: Optional[str] = "auto" + detail: Literal["auto", "low", "high"] = "auto" +class ImagePart(BaseModel): + type: Literal["image_url"] + image_url: ImageURL -class MessageContentItem(BaseModel): - type: Literal["text", "image_url"] - text: Optional[str] = None - image_url: Optional[ImageUrl] = None +class TextPart(BaseModel): + type: Literal["text"] + text: str +ContentPart = Union[TextPart, ImagePart] -class Message(BaseModel): +class Message(ChatCompletionMessage): role: Literal["system", "user", "assistant", "tool"] - content: Union[str, List[MessageContentItem]] + content: Union[str, List[ContentPart]] class Choice(OpenaAIChoice): From 3e3cc563c32df5987e471b5a2c5696541d0ba6aa Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 11:54:25 +0200 Subject: [PATCH 11/75] fix: ruff format --- packages/nilai-common/src/nilai_common/api_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index c449d11f..7dd948e9 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -12,16 +12,20 @@ class ImageURL(BaseModel): url: str detail: Literal["auto", "low", "high"] = "auto" + class ImagePart(BaseModel): type: Literal["image_url"] image_url: ImageURL + class TextPart(BaseModel): type: Literal["text"] text: str + ContentPart = Union[TextPart, ImagePart] + class Message(ChatCompletionMessage): role: Literal["system", "user", "assistant", "tool"] content: Union[str, List[ContentPart]] From 2a9669c2bd269d5ea838ab8fb49ccdf2119a0f8c Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 11:57:33 +0200 Subject: [PATCH 12/75] test: enhance chat completion tests with multimodal model integration --- tests/unit/nilai_api/routers/test_private.py | 28 +++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index a5cc489e..8e9471a8 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -209,7 +209,7 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien def test_chat_completion_with_image_support( - mock_user, mock_user_manager, mocker, client + mock_user, mock_user_manager, mock_state, mocker, client ): mocker.patch("openai.api_key", new="test-api-key") from openai.types.chat import ChatCompletion @@ -229,7 +229,6 @@ def test_chat_completion_with_image_support( mocker.patch( "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance ) - mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) multimodal_metadata = ModelMetadata( id="google/gemma-3-4b-it", @@ -276,9 +275,8 @@ def test_chat_completion_with_image_support( def test_chat_completion_with_image_unsupported_model( - mock_user, mock_user_manager, mocker, client + mock_user, mock_user_manager, mock_state, mocker, client ): - mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) response = client.post( "/v1/chat/completions", json={ @@ -306,9 +304,27 @@ def test_chat_completion_with_image_unsupported_model( def test_chat_completion_with_invalid_image_url( - mock_user, mock_user_manager, mocker, client + mock_user, mock_user_manager, mock_state, mocker, client ): - mocker.patch("nilai_api.rate_limiting.check_rate_limit", return_value=None) + from nilai_common import ModelMetadata, ModelEndpoint + + multimodal_metadata = ModelMetadata( + id="google/gemma-3-4b-it", + name="google/gemma-3-4b-it", + version="1.0", + description="Multimodal model", + author="Google", + license="Apache 2.0", + source="https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct", + supported_features=["chat_completion"], + tool_support=False, + multimodal_support=True, + ) + multimodal_endpoint = ModelEndpoint( + url="http://test-model-url", metadata=multimodal_metadata + ) + + mocker.patch.object(state, "get_model", return_value=multimodal_endpoint) response = client.post( "/v1/chat/completions", json={ From 3c629522ed7f9fb6e8d711c7e400464a6a7e1979 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 12:28:50 +0200 Subject: [PATCH 13/75] fix: web search + multimodal with 3 sources --- .../src/nilai_api/handlers/web_search.py | 46 +++++++++++++------ 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 31de093c..127517a6 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -15,6 +15,8 @@ TextPart, ) +from .image_support import multimodal_check + logger = logging.getLogger(__name__) _BRAVE_API_HEADERS = { @@ -151,7 +153,7 @@ async def perform_web_search_async(query: str) -> WebSearchContext: async def enhance_messages_with_web_search( - messages: List[Message], query: str + messages: List[Message], query: str, multimodal: bool = False ) -> WebSearchEnhancedMessages: """Enhance a list of messages with web search context.Collapse commentComment on line L155jcabrero commented on Aug 26, 2025 jcabreroon Aug 26, 2025MemberDeleted docstring?Write a replyResolve commentCode has comments. Press enter to view. @@ -173,16 +175,23 @@ async def enhance_messages_with_web_search( web_search_context = f"\n\nWeb search results:\n{ctx.prompt}" - last = messages[-1] - items = ( - [TextPart(type="text", text=last.content)] - if isinstance(last.content, str) - else list(last.content) - ) - items.append(TextPart(type="text", text=web_search_context)) - - enhanced_messages = list(messages) - enhanced_messages[-1] = Message(role="user", content=items) + if multimodal: + sys_idx = next((i for i, m in enumerate(enhanced_messages) if m.role == "system"), None) + if sys_idx is not None: + existing_message = enhanced_messages[sys_idx] + existing_content = existing_message.content + if isinstance(existing_content, str): + merged_content = f"{web_search_context}\n\n{existing_content}" if existing_content else web_search_context + else: + text_parts = [p.text for p in existing_content if isinstance(p, TextPart)] + merged_content = f"{web_search_context}\n\n" + "\n".join(text_parts) if text_parts else web_search_context + enhanced_messages[sys_idx] = Message(role="system", content=merged_content) + else: + system_ctx_message = Message(role="system", content=web_search_context) + enhanced_messages = [system_ctx_message] + list(messages) + else: + system_ctx_message = Message(role="system", content=web_search_context) + enhanced_messages = [system_ctx_message] + list(messages) return WebSearchEnhancedMessages( messages=enhanced_messages, sources=[query_source] + ctx.sources @@ -255,7 +264,14 @@ async def handle_web_search( user_query = "" for message in reversed(req_messages): if message.role == "user": - user_query = message.content + if isinstance(message.content, str): + user_query = message.content + else: + parts = [] + for part in message.content: + if isinstance(part, TextPart): + parts.append(part.text) + user_query = "\n".join(parts).strip() break if not user_query: return WebSearchEnhancedMessages(messages=req_messages, sources=[]) @@ -263,7 +279,11 @@ async def handle_web_search( concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - return await enhance_messages_with_web_search(req_messages, concise_query) + + is_multimodal = multimodal_check(req_messages).has_multimodal + return await enhance_messages_with_web_search( + req_messages, concise_query, multimodal=is_multimodal + ) except Exception: logger.warning("Web search enhancement failed") return WebSearchEnhancedMessages(messages=req_messages, sources=[]) From a52ec27ea61ab4a75fce8047be670ff7b7cce928 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 12:34:16 +0200 Subject: [PATCH 14/75] chore: stop tracking docker/compose/docker-compose.gemma-4b-gpu.ci.yml and ignore it --- .gitignore | 1 + .../docker-compose.gemma-4b-gpu.ci.yml | 51 ----- .../src/nilai_api/handlers/web_search.py | 22 ++- tests/e2e/test_openai.py | 39 ---- tests/unit/nilai_api/routers/test_private.py | 181 ------------------ 5 files changed, 18 insertions(+), 276 deletions(-) delete mode 100644 docker/compose/docker-compose.gemma-4b-gpu.ci.yml diff --git a/.gitignore b/.gitignore index 2a8f56cc..e2d624f0 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,4 @@ private_key.key.lock development-compose.yml production-compose.yml +docker/compose/docker-compose.gemma-4b-gpu.ci.yml diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml deleted file mode 100644 index 81203986..00000000 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ /dev/null @@ -1,51 +0,0 @@ -services: - gemma_4b_gpu: - image: nillion/nilai-vllm:latest - container_name: nilai-gemma_4b_gpu - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: all - capabilities: [gpu] - ipc: host - ulimits: - memlock: -1 - stack: 67108864 - env_file: - - .env - restart: unless-stopped - depends_on: - etcd: - condition: service_healthy - command: > - --model google/gemma-3-4b-it - --gpu-memory-utilization 0.7 - --max-model-len 8192 - --max-num-batched-tokens 8192 - --tensor-parallel-size 1 - --enable-auto-tool-choice - --tool-call-parser llama3_json - --uvicorn-log-level warning - --dtype half - environment: - - SVC_HOST=gemma_4b_gpu - - SVC_PORT=8000 - - ETCD_HOST=etcd - - ETCD_PORT=2379 - - TOOL_SUPPORT=true - - MULTIMODAL_SUPPORT=true - - VLLM_USE_V1=1 - - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 - - CUDA_LAUNCH_BLOCKING=1 - volumes: - - hugging_face_models:/root/.cache/huggingface - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8000/health"] - interval: 30s - retries: 3 - start_period: 60s - timeout: 10s -volumes: - hugging_face_models: diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 127517a6..003f02a5 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -176,15 +176,27 @@ async def enhance_messages_with_web_search( web_search_context = f"\n\nWeb search results:\n{ctx.prompt}" if multimodal: - sys_idx = next((i for i, m in enumerate(enhanced_messages) if m.role == "system"), None) + sys_idx = next( + (i for i, m in enumerate(enhanced_messages) if m.role == "system"), None + ) if sys_idx is not None: existing_message = enhanced_messages[sys_idx] existing_content = existing_message.content if isinstance(existing_content, str): - merged_content = f"{web_search_context}\n\n{existing_content}" if existing_content else web_search_context + merged_content = ( + f"{web_search_context}\n\n{existing_content}" + if existing_content + else web_search_context + ) else: - text_parts = [p.text for p in existing_content if isinstance(p, TextPart)] - merged_content = f"{web_search_context}\n\n" + "\n".join(text_parts) if text_parts else web_search_context + text_parts = [ + p.text for p in existing_content if isinstance(p, TextPart) + ] + merged_content = ( + f"{web_search_context}\n\n" + "\n".join(text_parts) + if text_parts + else web_search_context + ) enhanced_messages[sys_idx] = Message(role="system", content=merged_content) else: system_ctx_message = Message(role="system", content=web_search_context) @@ -279,7 +291,7 @@ async def handle_web_search( concise_query = await generate_search_query_from_llm( user_query, model_name, client ) - + is_multimodal = multimodal_check(req_messages).has_multimodal return await enhance_messages_with_web_search( req_messages, concise_query, multimodal=is_multimodal diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index c35a8e05..5c72d690 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -952,42 +952,3 @@ def make_request(): assert isinstance(error, openai.RateLimitError), ( "Rate limited responses should be RateLimitError" ) - - -def test_multimodal_with_web_search_e2e(client): - """Test that multimodal models with web search enabled return sources.""" - - # Base64 encoded image (minimal for testing) - base64_image = "" - - try: - response = client.chat.completions.create( - model="google/gemma-3-4b-it", - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What breed of dog is this and what are some interesting facts about this breed? Search the web for additional information.", - }, - {"type": "image_url", "image_url": {"url": base64_image}}, - ], - }, - ], - max_tokens=100, - temperature=0.0, - extra_body={"web_search": True}, - ) - - assert response.choices[0].message.content, "Response should contain content" - - # Check that sources are returned when web search is enabled - sources = getattr(response, "sources", None) - assert sources is not None, "Web search responses should have sources" - assert isinstance(sources, list), "Sources should be a list" - assert len(sources) > 0, "Sources should not be empty" - - except Exception as e: - raise Exception(f"Multimodal with web search test failed: {e}") diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 8e9471a8..9cf0ba4e 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -166,184 +166,3 @@ def test_get_models(mock_user, mock_user_manager, mock_state, client): ) assert response.status_code == 200 assert response.json() == [model_metadata.model_dump()] - - -def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, client): - mocker.patch("openai.api_key", new="test-api-key") - from openai.types.chat import ChatCompletion - - data = RESPONSE.model_dump() - data.pop("signature") - data.pop("sources", None) - response_data = ChatCompletion(**data) - # Patch nilai_api.routers.private.AsyncOpenAI to return a mock instance with chat.completions.create as an AsyncMock - mock_chat_completions = MagicMock() - mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) - mock_chat = MagicMock() - mock_chat.completions = mock_chat_completions - mock_async_openai_instance = MagicMock() - mock_async_openai_instance.chat = mock_chat - mocker.patch( - "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance - ) - response = client.post( - "/v1/chat/completions", - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, - ], - }, - headers={"Authorization": "Bearer test-api-key"}, - ) - assert response.status_code == 200 - assert "usage" in response.json() - assert response.json()["usage"] == { - "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150, - "completion_tokens_details": None, - "prompt_tokens_details": None, - } - - -def test_chat_completion_with_image_support( - mock_user, mock_user_manager, mock_state, mocker, client -): - mocker.patch("openai.api_key", new="test-api-key") - from openai.types.chat import ChatCompletion - from nilai_common import ModelMetadata, ModelEndpoint - - data = RESPONSE.model_dump() - data.pop("signature") - data.pop("sources", None) - response_data = ChatCompletion(**data) - - mock_chat_completions = MagicMock() - mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) - mock_chat = MagicMock() - mock_chat.completions = mock_chat_completions - mock_async_openai_instance = MagicMock() - mock_async_openai_instance.chat = mock_chat - mocker.patch( - "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance - ) - - multimodal_metadata = ModelMetadata( - id="google/gemma-3-4b-it", - name="google/gemma-3-4b-it", - version="1.0", - description="Multimodal model", - author="Google", - license="Apache 2.0", - source="https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct", - supported_features=["chat_completion"], - tool_support=False, - multimodal_support=True, - ) - multimodal_endpoint = ModelEndpoint( - url="http://test-model-url", metadata=multimodal_metadata - ) - - mocker.patch.object(state, "get_model", return_value=multimodal_endpoint) - - response = client.post( - "/v1/chat/completions", - json={ - "model": "google/gemma-3-4b-it", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "" - }, - }, - ], - }, - ], - }, - headers={"Authorization": "Bearer test-api-key"}, - ) - assert response.status_code == 200 - assert "usage" in response.json() - - -def test_chat_completion_with_image_unsupported_model( - mock_user, mock_user_manager, mock_state, mocker, client -): - response = client.post( - "/v1/chat/completions", - json={ - "model": "google/gemma-3-4b-it", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "" - }, - }, - ], - }, - ], - }, - headers={"Authorization": "Bearer test-api-key"}, - ) - assert response.status_code == 400 - assert "multimodal content" in response.json()["detail"] - - -def test_chat_completion_with_invalid_image_url( - mock_user, mock_user_manager, mock_state, mocker, client -): - from nilai_common import ModelMetadata, ModelEndpoint - - multimodal_metadata = ModelMetadata( - id="google/gemma-3-4b-it", - name="google/gemma-3-4b-it", - version="1.0", - description="Multimodal model", - author="Google", - license="Apache 2.0", - source="https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct", - supported_features=["chat_completion"], - tool_support=False, - multimodal_support=True, - ) - multimodal_endpoint = ModelEndpoint( - url="http://test-model-url", metadata=multimodal_metadata - ) - - mocker.patch.object(state, "get_model", return_value=multimodal_endpoint) - response = client.post( - "/v1/chat/completions", - json={ - "model": "google/gemma-3-4b-it", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image.jpg"}, - }, - ], - }, - ], - }, - headers={"Authorization": "Bearer test-api-key"}, - ) - assert response.status_code == 400 - assert "base64 data URLs" in response.json()["detail"] From ad4fa11cb554368ee89d84060543641ca2bcda88 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 12:37:40 +0200 Subject: [PATCH 15/75] refactor: clean up imports in tests --- nilai-api/src/nilai_api/handlers/web_search.py | 2 ++ tests/unit/nilai_api/routers/test_private.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 003f02a5..4f692d9d 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -175,6 +175,8 @@ async def enhance_messages_with_web_search( web_search_context = f"\n\nWeb search results:\n{ctx.prompt}" + enhanced_messages = list(messages) + if multimodal: sys_idx = next( (i for i, m in enumerate(enhanced_messages) if m.role == "system"), None diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 9cf0ba4e..fea83a0a 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -8,7 +8,7 @@ from nilai_common import AttestationReport from nilai_api.state import state -from ... import model_endpoint, model_metadata, response as RESPONSE +from ... import model_endpoint, model_metadata @pytest.mark.asyncio From f0e584865a99e71716aec4f62559dc63b17a1cfd Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 12:41:56 +0200 Subject: [PATCH 16/75] fix: add type ignore for role in Message class --- packages/nilai-common/src/nilai_common/api_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 7dd948e9..7def5a5e 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -27,7 +27,7 @@ class TextPart(BaseModel): class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] + role: Literal["system", "user", "assistant", "tool"] # type: ignore content: Union[str, List[ContentPart]] From 338a5ec54919ec3d3aeeaa503bd50950e86040b3 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:11:49 +0200 Subject: [PATCH 17/75] refactor: update Message class content type to use new ChatCompletion content parts --- .../src/nilai_common/api_model.py | 35 ++++++++----------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 7def5a5e..dc45e15d 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -5,30 +5,25 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage from openai.types.chat import ChatCompletionToolParam from openai.types.chat.chat_completion import Choice as OpenaAIChoice +from openai.types.chat.chat_completion_content_part_image_param import ( + ChatCompletionContentPartImageParam, +) +from openai.types.chat.chat_completion_content_part_text_param import ( + ChatCompletionContentPartTextParam, +) from pydantic import BaseModel, Field -class ImageURL(BaseModel): - url: str - detail: Literal["auto", "low", "high"] = "auto" - - -class ImagePart(BaseModel): - type: Literal["image_url"] - image_url: ImageURL - - -class TextPart(BaseModel): - type: Literal["text"] - text: str - - -ContentPart = Union[TextPart, ImagePart] - - class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] # type: ignore - content: Union[str, List[ContentPart]] + role: Literal["system", "user", "assistant", "tool"] + content: Union[ + str, + List[ + Union[ + ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam + ] + ], + ] class Choice(OpenaAIChoice): From e85a7117e8b233d2b1348c68cbbfc1c83dd59a11 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:12:02 +0200 Subject: [PATCH 18/75] feat: add content extractor utility for processing text and image content from chat completions --- .../src/nilai_api/utils/content_extractor.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 nilai-api/src/nilai_api/utils/content_extractor.py diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py new file mode 100644 index 00000000..f9d42dbd --- /dev/null +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -0,0 +1,35 @@ +from typing import Union, List +from openai.types.chat.chat_completion_content_part_text_param import ( + ChatCompletionContentPartTextParam, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ChatCompletionContentPartImageParam, +) + + +def extract_text_content( + content: Union[ + str, + List[ + Union[ + ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam + ] + ], + ], +) -> str: + """ + Extract text content from a message content field. + + Args: + content: Either a string or a list of content parts + + Returns: + str: The extracted text content, or empty string if no text content found + """ + if isinstance(content, str): + return content + elif isinstance(content, list): + for part in content: + if part.type == "text": + return part.text + return "" From 758d809826329590ec15c9069370045ce8b97b91 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:17:46 +0200 Subject: [PATCH 19/75] refactor: improve web search handling and enhance message context with user queries --- .../src/nilai_api/handlers/web_search.py | 155 +++++++----------- nilai-api/src/nilai_api/routers/private.py | 16 +- 2 files changed, 70 insertions(+), 101 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 4f692d9d..7df496d0 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -1,22 +1,20 @@ import logging from functools import lru_cache -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional import httpx from fastapi import HTTPException, status from nilai_api.config import WEB_SEARCH_SETTINGS +from nilai_api.utils.content_extractor import get_last_user_query from nilai_common.api_model import ( SearchResult, Source, WebSearchEnhancedMessages, WebSearchContext, Message, - TextPart, ) -from .image_support import multimodal_check - logger = logging.getLogger(__name__) _BRAVE_API_HEADERS = { @@ -64,26 +62,36 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Missing BRAVE_SEARCH_API key in environment", ) + q = " ".join(query.split()) q = " ".join(q.split()[:50])[:400] + params = {**_BRAVE_API_PARAMS_BASE, "q": q} headers = { **_BRAVE_API_HEADERS, "X-Subscription-Token": WEB_SEARCH_SETTINGS.api_key, } + client = _get_http_client() resp = await client.get( WEB_SEARCH_SETTINGS.api_path, headers=headers, params=params ) + if resp.status_code >= 400: logger.error("Brave API error: %s - %s", resp.status_code, resp.text) - error = HTTPException( + raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Web search failed, service currently unavailable", ) - error.status_code = 503 - raise error - return resp.json() + + try: + return resp.json() + except Exception: + logger.exception("Failed to parse Brave API JSON") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Web search failed: invalid response from provider", + ) def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: @@ -97,6 +105,7 @@ def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: """ web_block = data.get("web", {}) if isinstance(data, dict) else {} raw_results = web_block.get("results", []) if isinstance(web_block, dict) else [] + results: List[SearchResult] = [] for item in raw_results: if not isinstance(item, dict): @@ -104,14 +113,9 @@ def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: title = item.get("title", "")[:200] body = item.get("description") or item.get("snippet") or item.get("body", "") url = item.get("url") or item.get("link") or item.get("href", "") + if title and body and url: - results.append( - SearchResult( - title=title, - body=str(body)[:500], - url=str(url)[:500], - ) - ) + results.append(SearchResult(title=title, body=body, url=url)) return results @@ -147,13 +151,13 @@ async def perform_web_search_async(query: str) -> WebSearchContext: for idx, r in enumerate(results, start=1) ] prompt = "\n".join(lines) - sources = [Source(source=r.url, content=r.body) for r in results] + sources = [Source(source=r.url, content=r.body) for r in results] return WebSearchContext(prompt=prompt, sources=sources) async def enhance_messages_with_web_search( - messages: List[Message], query: str, multimodal: bool = False + messages: List[Message], query: str ) -> WebSearchEnhancedMessages: """Enhance a list of messages with web search context.Collapse commentComment on line L155jcabrero commented on Aug 26, 2025 jcabreroon Aug 26, 2025MemberDeleted docstring?Write a replyResolve commentCode has comments. Press enter to view. @@ -166,95 +170,67 @@ async def enhance_messages_with_web_search( context prepended as a system message, along with source information """ ctx = await perform_web_search_async(query) - query_source = Source(source="search_query", content=query) - - if not messages or messages[-1].role != "user": - return WebSearchEnhancedMessages( - messages=messages, sources=[query_source] + ctx.sources - ) - - web_search_context = f"\n\nWeb search results:\n{ctx.prompt}" - - enhanced_messages = list(messages) + query_source = Source(source="web_search_query", content=query) + + system_content = ( + f'You have access to the following web search results for the query: "{query}"\n\n' + "Use this information to provide accurate and up-to-date answers. " + "Cite the sources when appropriate.\n\n" + "Web Search Results:\n" + f"{ctx.prompt}\n\n" + "Please provide a comprehensive answer based on the search results above." + ) - if multimodal: - sys_idx = next( - (i for i, m in enumerate(enhanced_messages) if m.role == "system"), None - ) - if sys_idx is not None: - existing_message = enhanced_messages[sys_idx] - existing_content = existing_message.content - if isinstance(existing_content, str): - merged_content = ( - f"{web_search_context}\n\n{existing_content}" - if existing_content - else web_search_context - ) - else: - text_parts = [ - p.text for p in existing_content if isinstance(p, TextPart) - ] - merged_content = ( - f"{web_search_context}\n\n" + "\n".join(text_parts) - if text_parts - else web_search_context - ) - enhanced_messages[sys_idx] = Message(role="system", content=merged_content) - else: - system_ctx_message = Message(role="system", content=web_search_context) - enhanced_messages = [system_ctx_message] + list(messages) - else: - system_ctx_message = Message(role="system", content=web_search_context) - enhanced_messages = [system_ctx_message] + list(messages) + system_message = Message(role="system", content=system_content) + enhanced = list(messages) + [system_message] return WebSearchEnhancedMessages( - messages=enhanced_messages, sources=[query_source] + ctx.sources + messages=enhanced, + sources=[query_source] + ctx.sources, ) async def generate_search_query_from_llm( user_message: str, model_name: str, client ) -> str: - system_prompt = """ - You are given a user question. Your task is to generate a concise web search query that will best retrieve information to answer the question. If the user’s question is already optimal, simply repeat it as the query. This is essentially summarization, paraphrasing, and key term extraction. - - - Do not add guiding elements or assumptions that the user did not explicitly request. - - Do not answer the query. - - The query must contain at least 10 words. - - Output only the search query. - - ### Example - - **User:** Who won the Roland Garros Open in 2024? Just reply with the winner's name. - **Search query:** Roland Garros 2024 tennis tournament winner men women champion """ + Use the LLM to produce a concise, high-recall search query. + """ + system_prompt = ( + "You are given a user question. Generate a concise web search query that will best retrieve information " + "to answer the question. If the user’s question is already optimal, repeat it exactly.\n" + "- Do not add assumptions not present in the question.\n" + "- Do not answer the question.\n" + "- The query must contain at least 10 words.\n" + "Output only the search query." + ) + messages = [ Message(role="system", content=system_prompt), Message(role="user", content=user_message), ] + req = { "model": model_name, "messages": [m.model_dump() for m in messages], "max_tokens": 150, } + try: response = await client.chat.completions.create(**req) except Exception as exc: - raise RuntimeError(f"Failed to generate search query: {str(exc)}") from exc - - if not response.choices: - raise RuntimeError("LLM returned an empty search query") + raise RuntimeError(f"Failed to generate search query: {exc}") from exc try: - content = response.choices[0].message.content.strip() - except (AttributeError, IndexError, TypeError) as exc: - raise RuntimeError(f"Invalid response structure from LLM: {str(exc)}") from exc + choices = getattr(response, "choices", None) or [] + msg = choices[0].message + content = (getattr(msg, "content", None) or "").strip() + except Exception as exc: + raise RuntimeError(f"Invalid response structure from LLM: {exc}") from exc if not content: raise RuntimeError("LLM returned an empty search query") - logger.debug("Generated search query: %s", content) - return content @@ -275,29 +251,18 @@ async def handle_web_search( WebSearchEnhancedMessages with web search context added, or original messages if no user query is found or search fails """ - user_query = "" - for message in reversed(req_messages): - if message.role == "user": - if isinstance(message.content, str): - user_query = message.content - else: - parts = [] - for part in message.content: - if isinstance(part, TextPart): - parts.append(part.text) - user_query = "\n".join(parts).strip() - break + user_query = get_last_user_query(req_messages) if not user_query: return WebSearchEnhancedMessages(messages=req_messages, sources=[]) + try: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) + return await enhance_messages_with_web_search(req_messages, concise_query) + + except HTTPException: + return WebSearchEnhancedMessages(messages=req_messages, sources=[]) - is_multimodal = multimodal_check(req_messages).has_multimodal - return await enhance_messages_with_web_search( - req_messages, concise_query, multimodal=is_multimodal - ) except Exception: - logger.warning("Web search enhancement failed") return WebSearchEnhancedMessages(messages=req_messages, sources=[]) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 6c3bb6a6..fe54250e 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -6,7 +6,7 @@ from nilai_api.attestation import get_attestation_report from nilai_api.handlers.nilrag import handle_nilrag from nilai_api.handlers.web_search import handle_web_search -from nilai_api.handlers.image_support import multimodal_check +from nilai_api.utils.content_extractor import has_multimodal_content from fastapi import APIRouter, Body, Depends, HTTPException, status, Request from fastapi.responses import StreamingResponse @@ -213,15 +213,18 @@ async def chat_completion( detail="Model does not support tool usage, remove tools from request", ) - multimodal_result = multimodal_check(req.messages) - if multimodal_result.has_multimodal: + has_multimodal = has_multimodal_content(req.messages) + if has_multimodal: if not endpoint.metadata.multimodal_support: raise HTTPException( status_code=400, detail="Model does not support multimodal content, remove image inputs from request", ) - if multimodal_result.error: - raise HTTPException(status_code=400, detail=multimodal_result.error) + if req.web_search: + raise HTTPException( + status_code=400, + detail="Web search is not supported with multimodal (image) content. Use text-only input for web search.", + ) model_url = endpoint.url + "/v1/" @@ -236,7 +239,8 @@ async def chat_completion( messages = req.messages sources: Optional[List[Source]] = None - if req.web_search: + + if req.web_search and not has_multimodal: web_search_result = await handle_web_search(messages, model_name, client) messages = web_search_result.messages sources = web_search_result.sources From fc9ea5d78c5bfccf0ad26f591621600bf2aa8a60 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:18:01 +0200 Subject: [PATCH 20/75] refactor: integrate content extraction into user query handling and enhance system message updates --- nilai-api/src/nilai_api/handlers/nilrag.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index b55ac18b..54582ae3 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -5,7 +5,7 @@ from nilai_common import ChatRequest, Message from fastapi import HTTPException, status from sentence_transformers import SentenceTransformer -from typing import Union +from nilai_api.utils.content_extractor import extract_text_content logger = logging.getLogger(__name__) @@ -66,10 +66,10 @@ async def handle_nilrag(req: ChatRequest): query = None for message in req.messages: if message.role == "user": - query = message.content + query = extract_text_content(message.content) break - if query is None: + if not query: raise HTTPException(status_code=400, detail="No user query found") # Get number of chunks to include @@ -92,9 +92,11 @@ async def handle_nilrag(req: ChatRequest): status_code=status.HTTP_400_BAD_REQUEST, detail="system message is empty", ) - message.content += ( - relevant_context # Append the context to the system message - ) + + if isinstance(message.content, str): + message.content += relevant_context + elif isinstance(message.content, list): + message.content.append({"type": "text", "text": relevant_context}) break else: # If no system message exists, add one From 758dc0d478cfb7c813ad5b8d1d1f4dd482717a42 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:18:29 +0200 Subject: [PATCH 21/75] feat: add functions to handle multimodal content and extract the last user query from messages --- .../src/nilai_api/utils/content_extractor.py | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index f9d42dbd..e77bc256 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -1,10 +1,12 @@ -from typing import Union, List +import logging +from typing import Union, List, Optional from openai.types.chat.chat_completion_content_part_text_param import ( ChatCompletionContentPartTextParam, ) from openai.types.chat.chat_completion_content_part_image_param import ( ChatCompletionContentPartImageParam, ) +from nilai_common import Message def extract_text_content( @@ -33,3 +35,31 @@ def extract_text_content( if part.type == "text": return part.text return "" + + +def has_multimodal_content(messages: List[Message]) -> bool: + """Check if any message contains multimodal content (image_url parts).""" + last_message = messages[-1] + last_message_content = last_message.content + is_multimodal = isinstance(last_message_content, list) and any( + isinstance(item, dict) and item.get("type") == "image_url" + for item in last_message_content + ) + return is_multimodal + + +def get_last_user_query(messages: List[Message]) -> Optional[str]: + """ + Walk from the end to find the most recent user-authored content, and + extract text from either a string or a multimodal content list. + """ + for msg in reversed(messages): + if getattr(msg, "role", None) == "user": + content = getattr(msg, "content", None) + try: + text = extract_text_content(content) + except Exception: + text = None + if text: + return text.strip() + return None From eaf2e9669132cf573762ae2e7ee174b5368443d1 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:35:19 +0200 Subject: [PATCH 22/75] refactor: remove deprecated image support handler and enhance multimodal completion tests --- .../src/nilai_api/handlers/image_support.py | 92 -------------- nilai-api/src/nilai_api/handlers/nilrag.py | 1 + tests/e2e/test_openai.py | 44 +++++++ tests/unit/nilai_api/routers/test_private.py | 114 ++++++++++++++++++ 4 files changed, 159 insertions(+), 92 deletions(-) delete mode 100644 nilai-api/src/nilai_api/handlers/image_support.py diff --git a/nilai-api/src/nilai_api/handlers/image_support.py b/nilai-api/src/nilai_api/handlers/image_support.py deleted file mode 100644 index f353912c..00000000 --- a/nilai-api/src/nilai_api/handlers/image_support.py +++ /dev/null @@ -1,92 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Any -from fastapi import HTTPException -from nilai_common import Message - - -@dataclass(frozen=True) -class MultimodalCheck: - has_multimodal: bool - error: Optional[str] = None - - -def _extract_url(image_url_field: Any) -> Optional[str]: - """ - Support both object-with-attr and dict-like shapes. - Returns the URL string or None. - """ - if image_url_field is None: - return None - - url = getattr(image_url_field, "url", None) - if url is not None: - return url - if isinstance(image_url_field, dict): - return image_url_field.get("url") - return None - - -def multimodal_check(messages: List[Message]) -> MultimodalCheck: - """ - Single-pass check: - - detect if any part is type=='image_url' - - validate that image_url.url exists and is a base64 data URL - Returns: - MultimodalCheck(has_multimodal: bool, error: Optional[str]) - """ - has_mm = False - - for m in messages: - content = getattr(m, "content", None) or [] - for item in content: - if getattr(item, "type", None) == "image_url": - has_mm = True - iu = getattr(item, "image_url", None) - url = _extract_url(iu) - if not url: - return MultimodalCheck( - True, "image_url.url is required for image_url parts" - ) - if not (url.startswith("data:image/") and ";base64," in url): - return MultimodalCheck( - True, - "Only base64 data URLs are allowed for images (data:image/...;base64,...)", - ) - - return MultimodalCheck(has_mm, None) - - -def has_multimodal_content( - messages: List[Message], precomputed: Optional[MultimodalCheck] = None -) -> bool: - """ - Check if any message contains multimodal content (image_url parts). - - Args: - messages: List of messages to check - precomputed: Optional precomputed result from multimodal_check() to avoid re-iterating - - Returns: - True if any message contains image_url parts, False otherwise - """ - res = precomputed or multimodal_check(messages) - return res.has_multimodal - - -def validate_multimodal_content( - messages: List[Message], precomputed: Optional[MultimodalCheck] = None -) -> None: - """ - Validate that multimodal content (image_url parts) follows the required format. - - Args: - messages: List of messages to validate - precomputed: Optional precomputed result from multimodal_check() to avoid re-iterating - - Raises: - HTTPException(400): When image_url parts don't have required URL or use invalid format - (only base64 data URLs are allowed: data:image/...;base64,...) - """ - res = precomputed or multimodal_check(messages) - if res.error: - raise HTTPException(status_code=400, detail=res.error) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index 54582ae3..09e8fae0 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -1,4 +1,5 @@ import logging +from typing import Union import nilrag diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 5c72d690..76ebc322 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -952,3 +952,47 @@ def make_request(): assert isinstance(error, openai.RateLimitError), ( "Rate limited responses should be RateLimitError" ) + + +def test_multimodal_completion(client): + """Test basic multimodal completion with image content.""" + try: + response = client.chat.completions.create( + model=test_models[0], + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + max_tokens=50, + ) + + assert isinstance(response, ChatCompletion), ( + "Response should be a ChatCompletion object" + ) + assert len(response.choices) > 0, "Response should contain at least one choice" + + content = response.choices[0].message.content + assert content, "Response should contain content" + + print(f"\nMultimodal response: {content[:100]}...") + + assert response.usage, "Response should contain usage data" + assert response.usage.prompt_tokens > 0, ( + "Prompt tokens should be greater than 0" + ) + assert response.usage.completion_tokens > 0, ( + "Completion tokens should be greater than 0" + ) + + except Exception as e: + pytest.fail(f"Error testing multimodal completion: {str(e)}") diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index fea83a0a..c77b040e 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -166,3 +166,117 @@ def test_get_models(mock_user, mock_user_manager, mock_state, client): ) assert response.status_code == 200 assert response.json() == [model_metadata.model_dump()] + + +def test_web_search_with_multimodal_content_error( + mock_user, mock_user_manager, mock_state, client +): + """Test that web search with multimodal content returns 400 error.""" + response = client.post( + "/v1/chat/completions", + headers={"Authorization": "Bearer test-api-key"}, + json={ + "model": "ABC", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + "web_search": True, + }, + ) + assert response.status_code == 400 + assert ( + "Web search is not supported with multimodal (image) content" + in response.json()["detail"] + ) + + +def test_web_search_with_text_only_works( + mock_user, mock_user_manager, mock_state, client, mocker +): + """Test that web search with text-only content works normally.""" + + mock_web_search = mocker.patch("nilai_api.routers.private.handle_web_search") + mock_web_search.return_value = MagicMock(messages=[], sources=[]) + + mock_client = mocker.patch("nilai_api.routers.private.AsyncOpenAI") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Test response" + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_client.return_value.chat.completions.create.return_value = mock_response + + mocker.patch( + "nilai_api.routers.private.create_signed_chat_completion", + return_value={"test": "response"}, + ) + + response = client.post( + "/v1/chat/completions", + headers={"Authorization": "Bearer test-api-key"}, + json={ + "model": "ABC", + "messages": [{"role": "user", "content": "What is the latest AI news?"}], + "web_search": True, + }, + ) + + assert response.status_code == 200 + mock_web_search.assert_called_once() + + +def test_multimodal_completion( + mock_user, mock_user_manager, mock_state, client, mocker +): + """Test basic multimodal completion with image content.""" + mock_client = mocker.patch("nilai_api.routers.private.AsyncOpenAI") + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "This appears to be an image." + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 15 + mock_response.usage.completion_tokens = 8 + mock_response.usage.total_tokens = 23 + mock_client.return_value.chat.completions.create.return_value = mock_response + + mocker.patch( + "nilai_api.routers.private.create_signed_chat_completion", + return_value={"test": "response"}, + ) + + response = client.post( + "/v1/chat/completions", + headers={"Authorization": "Bearer test-api-key"}, + json={ + "model": "ABC", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + }, + ) + assert response.status_code == 200 + assert response.json() == {"test": "response"} From 7973642a1dd4f90f874bf8c09e7ab82c1cda9e2c Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:37:39 +0200 Subject: [PATCH 23/75] refactor: clean up unused imports in web search and content extractor modules --- nilai-api/src/nilai_api/handlers/web_search.py | 2 +- nilai-api/src/nilai_api/utils/content_extractor.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 7df496d0..0273f76a 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -1,6 +1,6 @@ import logging from functools import lru_cache -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any import httpx from fastapi import HTTPException, status diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index e77bc256..a89d9576 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -1,4 +1,3 @@ -import logging from typing import Union, List, Optional from openai.types.chat.chat_completion_content_part_text_param import ( ChatCompletionContentPartTextParam, From f9b71cd4c413faac452867a3c8c7c04ddb521d63 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Tue, 26 Aug 2025 18:46:53 +0200 Subject: [PATCH 24/75] feat: implement chat completion tests with image support and error handling for multimodal content --- tests/unit/nilai_api/routers/test_private.py | 152 ++++++++++--------- 1 file changed, 81 insertions(+), 71 deletions(-) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index c77b040e..4f5a67d3 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -8,7 +8,7 @@ from nilai_common import AttestationReport from nilai_api.state import state -from ... import model_endpoint, model_metadata +from ... import model_endpoint, model_metadata, response as RESPONSE @pytest.mark.asyncio @@ -168,115 +168,125 @@ def test_get_models(mock_user, mock_user_manager, mock_state, client): assert response.json() == [model_metadata.model_dump()] -def test_web_search_with_multimodal_content_error( - mock_user, mock_user_manager, mock_state, client -): - """Test that web search with multimodal content returns 400 error.""" +def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, client): + mocker.patch("openai.api_key", new="test-api-key") + from openai.types.chat import ChatCompletion + + data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + # Patch nilai_api.routers.private.AsyncOpenAI to return a mock instance with chat.completions.create as an AsyncMock + mock_chat_completions = MagicMock() + mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat + mocker.patch( + "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance + ) response = client.post( "/v1/chat/completions", + json={ + "model": "meta-llama/Llama-3.2-1B-Instruct", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is your name?"}, + ], + }, headers={"Authorization": "Bearer test-api-key"}, + ) + assert response.status_code == 200 + assert "usage" in response.json() + assert response.json()["usage"] == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "completion_tokens_details": None, + "prompt_tokens_details": None, + } + + +def test_chat_completion_image_web_search_error( + mock_user, mock_state, mock_user_manager, client +): + response = client.post( + "/v1/chat/completions", json={ - "model": "ABC", + "model": "google/gemma-3-4b-it", "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ - {"type": "text", "text": "What is this image?"}, + {"type": "text", "text": "What is in this image?"}, { "type": "image_url", "image_url": { - "url": "" + "url": "" }, }, ], - } + }, ], "web_search": True, }, - ) - assert response.status_code == 400 - assert ( - "Web search is not supported with multimodal (image) content" - in response.json()["detail"] - ) - - -def test_web_search_with_text_only_works( - mock_user, mock_user_manager, mock_state, client, mocker -): - """Test that web search with text-only content works normally.""" - - mock_web_search = mocker.patch("nilai_api.routers.private.handle_web_search") - mock_web_search.return_value = MagicMock(messages=[], sources=[]) - - mock_client = mocker.patch("nilai_api.routers.private.AsyncOpenAI") - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = "Test response" - mock_response.usage = MagicMock() - mock_response.usage.prompt_tokens = 10 - mock_response.usage.completion_tokens = 5 - mock_response.usage.total_tokens = 15 - mock_client.return_value.chat.completions.create.return_value = mock_response - - mocker.patch( - "nilai_api.routers.private.create_signed_chat_completion", - return_value={"test": "response"}, - ) - - response = client.post( - "/v1/chat/completions", headers={"Authorization": "Bearer test-api-key"}, - json={ - "model": "ABC", - "messages": [{"role": "user", "content": "What is the latest AI news?"}], - "web_search": True, - }, ) - - assert response.status_code == 200 - mock_web_search.assert_called_once() + assert response.status_code == 400 + assert "web_search" in response.json()["detail"].lower() -def test_multimodal_completion( - mock_user, mock_user_manager, mock_state, client, mocker +def test_chat_completion_with_image( + mock_user, mock_state, mock_user_manager, mocker, client ): - """Test basic multimodal completion with image content.""" - mock_client = mocker.patch("nilai_api.routers.private.AsyncOpenAI") - mock_response = MagicMock() - mock_response.choices = [MagicMock()] - mock_response.choices[0].message.content = "This appears to be an image." - mock_response.usage = MagicMock() - mock_response.usage.prompt_tokens = 15 - mock_response.usage.completion_tokens = 8 - mock_response.usage.total_tokens = 23 - mock_client.return_value.chat.completions.create.return_value = mock_response - + mocker.patch("openai.api_key", new="test-api-key") + from openai.types.chat import ChatCompletion + + data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + + mock_chat_completions = MagicMock() + mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat mocker.patch( - "nilai_api.routers.private.create_signed_chat_completion", - return_value={"test": "response"}, + "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance ) response = client.post( "/v1/chat/completions", - headers={"Authorization": "Bearer test-api-key"}, json={ - "model": "ABC", + "model": "google/gemma-3-4b-it", "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ - {"type": "text", "text": "What is this image?"}, + {"type": "text", "text": "What is in this image?"}, { "type": "image_url", "image_url": { - "url": "" + "url": "" }, }, ], - } + }, ], }, + headers={"Authorization": "Bearer test-api-key"}, ) assert response.status_code == 200 - assert response.json() == {"test": "response"} + assert "usage" in response.json() + assert response.json()["usage"] == { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "completion_tokens_details": None, + "prompt_tokens_details": None, + } From 746edbf8a5a5dbdf1eebcb03848c7ad87e1b5956 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 11:05:35 +0200 Subject: [PATCH 25/75] feat: enhance chat completion tests with rate limit configurations and improved error handling for multimodal content --- tests/unit/nilai_api/routers/test_private.py | 32 +++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index 4f5a67d3..a0b7f486 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -28,6 +28,12 @@ def mock_user(): mock.completion_tokens_details = None mock.prompt_tokens_details = None mock.queries = 10 + mock.ratelimit_minute = 100 + mock.ratelimit_hour = 1000 + mock.ratelimit_day = 10000 + mock.web_search_ratelimit_minute = 100 + mock.web_search_ratelimit_hour = 1000 + mock.web_search_ratelimit_day = 10000 return mock @@ -209,8 +215,26 @@ def test_chat_completion(mock_user, mock_state, mock_user_manager, mocker, clien def test_chat_completion_image_web_search_error( - mock_user, mock_state, mock_user_manager, client + mock_user, mock_state, mock_user_manager, mocker, client ): + mocker.patch("openai.api_key", new="test-api-key") + from openai.types.chat import ChatCompletion + + mocker.patch.object(model_metadata, "multimodal_support", True) + + data = RESPONSE.model_dump() + data.pop("signature") + data.pop("sources", None) + response_data = ChatCompletion(**data) + mock_chat_completions = MagicMock() + mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) + mock_chat = MagicMock() + mock_chat.completions = mock_chat_completions + mock_async_openai_instance = MagicMock() + mock_async_openai_instance.chat = mock_chat + mocker.patch( + "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance + ) response = client.post( "/v1/chat/completions", json={ @@ -235,7 +259,6 @@ def test_chat_completion_image_web_search_error( headers={"Authorization": "Bearer test-api-key"}, ) assert response.status_code == 400 - assert "web_search" in response.json()["detail"].lower() def test_chat_completion_with_image( @@ -244,11 +267,13 @@ def test_chat_completion_with_image( mocker.patch("openai.api_key", new="test-api-key") from openai.types.chat import ChatCompletion + # Mock the model to support multimodal content + mocker.patch.object(model_metadata, "multimodal_support", True) + data = RESPONSE.model_dump() data.pop("signature") data.pop("sources", None) response_data = ChatCompletion(**data) - mock_chat_completions = MagicMock() mock_chat_completions.create = mocker.AsyncMock(return_value=response_data) mock_chat = MagicMock() @@ -258,7 +283,6 @@ def test_chat_completion_with_image( mocker.patch( "nilai_api.routers.private.AsyncOpenAI", return_value=mock_async_openai_instance ) - response = client.post( "/v1/chat/completions", json={ From 52e3ad9059cd04a08b321315e9ddb18c9e662d81 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 11:10:57 +0200 Subject: [PATCH 26/75] refactor: streamline rate limiting logic by removing unused wait_for_bucket method and optimizing check_bucket calls --- nilai-api/src/nilai_api/rate_limiting.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index a56fe019..d1dc84a3 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -9,6 +9,7 @@ from fastapi import status, HTTPException, Request from redis.asyncio import from_url, Redis + from nilai_api.auth import get_auth_info, AuthenticationInfo, TokenRateLimits LUA_RATE_LIMIT_SCRIPT = """ @@ -139,7 +140,6 @@ async def __call__( # The value is the usage limit # The expiration is the time remaining in validity of the token # We use the time remaining to check if the token rate limit is exceeded - for limit in user_limits.token_rate_limit.limits: await self.check_bucket( redis, @@ -163,10 +163,10 @@ async def __call__( // WEB_SEARCH_SETTINGS.count, ), ) - await self.wait_for_bucket( + await self.check_bucket( redis, redis_rate_limit_command, - "global:web_search:rps", + f"web_search_rps:{user_limits.subscription_holder}", allowed_rps, 1000, ) @@ -212,24 +212,6 @@ async def check_bucket( headers={"Retry-After": str(expire)}, ) - @staticmethod - async def wait_for_bucket( - redis: Redis, - redis_rate_limit_command: str, - key: str, - times: int | None, - milliseconds: int, - ): - if times is None: - return - while True: - expire = await redis.evalsha( - redis_rate_limit_command, 1, key, str(times), str(milliseconds) - ) # type: ignore - if int(expire) == 0: - return - await asyncio.sleep((int(expire) + 50) / 1000) - async def check_concurrent_and_increment( self, redis: Redis, request: Request ) -> str | None: From b688bc0a4c2d31a01e9a77e393fe18b9779e30a9 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 11:20:27 +0200 Subject: [PATCH 27/75] refactor: remove unused import of asyncio in rate limiting module --- nilai-api/src/nilai_api/rate_limiting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nilai-api/src/nilai_api/rate_limiting.py b/nilai-api/src/nilai_api/rate_limiting.py index d1dc84a3..8ca56a16 100644 --- a/nilai-api/src/nilai_api/rate_limiting.py +++ b/nilai-api/src/nilai_api/rate_limiting.py @@ -1,4 +1,3 @@ -import asyncio from asyncio import iscoroutine from typing import Callable, Tuple, Awaitable, Annotated From 49515d0432bf43ab44058cbe98115c6ae468986c Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 11:29:10 +0200 Subject: [PATCH 28/75] refactor: simplify rate limiting tests by consolidating success and rejection handling --- tests/unit/nilai_api/test_rate_limiting.py | 65 +++++----------------- tests/unit/nilai_api/test_web_search.py | 28 ---------- 2 files changed, 15 insertions(+), 78 deletions(-) diff --git a/tests/unit/nilai_api/test_rate_limiting.py b/tests/unit/nilai_api/test_rate_limiting.py index 940c6199..ab914a5e 100644 --- a/tests/unit/nilai_api/test_rate_limiting.py +++ b/tests/unit/nilai_api/test_rate_limiting.py @@ -198,55 +198,20 @@ async def test_global_web_search_rps_limit(req, redis_client, monkeypatch): web_search_minute_limit=None, ) - async def run_guarded(i, times, t0): - async for _ in rate_limit(req, user_limits): - times[i] = asyncio.get_event_loop().time() - t0 - await asyncio.sleep(0.01) + async def run_guarded(i, results): + try: + async for _ in rate_limit(req, user_limits): + results[i] = "ok" + await asyncio.sleep(0.01) + except HTTPException as e: + results[i] = e.status_code n = 40 - times = [0.0] * n - t0 = asyncio.get_event_loop().time() - tasks = [asyncio.create_task(run_guarded(i, times, t0)) for i in range(n)] - await asyncio.gather(*tasks) - - within_first_second = [t for t in times if t < 1.0] - assert len(within_first_second) <= 20 - assert max(times) >= 1.0 - - -@pytest.mark.asyncio -async def test_queueing_across_seconds(req, redis_client, monkeypatch): - from nilai_api import rate_limiting as rl - - await redis_client[0].delete("global:web_search:rps") - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "rps", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "max_concurrent_requests", 20) - monkeypatch.setattr(rl.WEB_SEARCH_SETTINGS, "count", 1) - - rate_limit = RateLimit(web_search_extractor=lambda _: True) - user_limits = UserRateLimits( - subscription_holder=random_id(), - day_limit=None, - hour_limit=None, - minute_limit=None, - token_rate_limit=None, - web_search_day_limit=None, - web_search_hour_limit=None, - web_search_minute_limit=None, - ) - - async def run_guarded(i, times, t0): - async for _ in rate_limit(req, user_limits): - times[i] = asyncio.get_event_loop().time() - t0 - await asyncio.sleep(0.01) - - n = 25 - times = [0.0] * n - t0 = asyncio.get_event_loop().time() - tasks = [asyncio.create_task(run_guarded(i, times, t0)) for i in range(n)] - await asyncio.gather(*tasks) - - first_window = [t for t in times if t < 1.0] - second_window = [t for t in times if 1.0 <= t < 2.0] - assert len(first_window) <= 20 - assert len(second_window) >= 1 + results = [None] * n + tasks = [asyncio.create_task(run_guarded(i, results)) for i in range(n)] + await asyncio.gather(*tasks, return_exceptions=True) + + successes = [r for r in results if r == "ok"] + rejections = [r for r in results if r == 429] + assert len(successes) <= 20 + assert len(rejections) >= 20 diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index 7b83088d..912fcb4e 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -135,31 +135,3 @@ async def test_perform_web_search_async_concurrent_queries(): results[1].sources[0].content == "Advances in machine learning algorithms." ) - -@pytest.mark.asyncio -async def test_enhance_messages_with_web_search(): - """Test message enhancement with web search results and source validation""" - original_messages = [ - Message(role="system", content="You are a helpful assistant"), - Message(role="user", content="What is the latest AI news?"), - ] - - with patch("nilai_api.handlers.web_search.perform_web_search_async") as mock_search: - mock_search.return_value = WebSearchContext( - prompt="[1] Latest AI Developments\nURL: https://example.com\nSnippet: OpenAI announces GPT-5", - sources=[ - Source(source="https://example.com", content="OpenAI announces GPT-5") - ], - ) - - enhanced = await enhance_messages_with_web_search(original_messages, "AI news") - - assert len(enhanced.messages) == 3 - assert enhanced.messages[0].role == "system" - assert "Latest AI Developments" in str(enhanced.messages[0].content) - assert enhanced.sources is not None - assert len(enhanced.sources) == 2 - assert enhanced.sources[0].source == "search_query" - assert enhanced.sources[0].content == "AI news" - assert enhanced.sources[1].source == "https://example.com" - assert enhanced.sources[1].content == "OpenAI announces GPT-5" From 9d4001585cce294637ae306bfc94b5c74532d3cd Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 11:31:42 +0200 Subject: [PATCH 29/75] refactor: remove redundant blank line in web search test file --- tests/unit/nilai_api/test_web_search.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index 912fcb4e..0d61d1fb 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -134,4 +134,3 @@ async def test_perform_web_search_async_concurrent_queries(): assert ( results[1].sources[0].content == "Advances in machine learning algorithms." ) - From bbb3d4bf41207a1360cb459c4c472e11d340f2b9 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 11:34:07 +0200 Subject: [PATCH 30/75] fix: unused imports --- tests/unit/nilai_api/test_web_search.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/unit/nilai_api/test_web_search.py b/tests/unit/nilai_api/test_web_search.py index 0d61d1fb..124e2eb9 100644 --- a/tests/unit/nilai_api/test_web_search.py +++ b/tests/unit/nilai_api/test_web_search.py @@ -3,12 +3,6 @@ from fastapi import HTTPException from nilai_api.handlers.web_search import ( perform_web_search_async, - enhance_messages_with_web_search, -) -from nilai_common import Message -from nilai_common.api_model import ( - WebSearchContext, - Source, ) From a93a72bee505f41678711a9691c948d794ffbec3 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 16:37:35 +0200 Subject: [PATCH 31/75] refactor: update type annotations in Message class --- packages/nilai-common/src/nilai_common/api_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index dc45e15d..9d84ba52 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -15,15 +15,15 @@ class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] - content: Union[ + role: Literal["system", "user", "assistant", "tool"] #type: ignore + content: Optional[Union[ str, List[ Union[ ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam ] ], - ] + ]] # type: ignore[override] class Choice(OpenaAIChoice): From 83cd37359673c9561148483763a9154285c7e490 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 16:38:43 +0200 Subject: [PATCH 32/75] fix: ruff --- .../src/nilai_common/api_model.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 9d84ba52..1dcfeade 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -15,15 +15,18 @@ class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] #type: ignore - content: Optional[Union[ - str, - List[ - Union[ - ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam - ] - ], - ]] # type: ignore[override] + role: Literal["system", "user", "assistant", "tool"] # type: ignore + content: Optional[ + Union[ + str, + List[ + Union[ + ChatCompletionContentPartTextParam, + ChatCompletionContentPartImageParam, + ] + ], + ] + ] # type: ignore[override] class Choice(OpenaAIChoice): From 0b3d622e67ccdfcd69d33e16719d9227ff38d2db Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 16:45:05 +0200 Subject: [PATCH 33/75] fix: handle None content in message processing and update content extraction logic --- nilai-api/src/nilai_api/handlers/nilrag.py | 5 +++-- .../src/nilai_api/utils/content_extractor.py | 17 +++++++++-------- .../nilai-common/src/nilai_common/api_model.py | 2 +- tests/unit/__init__.py | 3 ++- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index 09e8fae0..d82ddc34 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -67,8 +67,9 @@ async def handle_nilrag(req: ChatRequest): query = None for message in req.messages: if message.role == "user": - query = extract_text_content(message.content) - break + if message.content is not None: + query = extract_text_content(message.content) + break if not query: raise HTTPException(status_code=400, detail="No user query found") diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index a89d9576..19dc93af 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -31,8 +31,8 @@ def extract_text_content( return content elif isinstance(content, list): for part in content: - if part.type == "text": - return part.text + if part["type"] == "text": + return part["text"] return "" @@ -55,10 +55,11 @@ def get_last_user_query(messages: List[Message]) -> Optional[str]: for msg in reversed(messages): if getattr(msg, "role", None) == "user": content = getattr(msg, "content", None) - try: - text = extract_text_content(content) - except Exception: - text = None - if text: - return text.strip() + if content is not None: + try: + text = extract_text_content(content) + except Exception: + text = None + if text: + return text.strip() return None diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 1dcfeade..98708065 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -26,7 +26,7 @@ class Message(ChatCompletionMessage): ] ], ] - ] # type: ignore[override] + ] = None # type: ignore[override] class Choice(OpenaAIChoice): diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 7e22d6dd..ec5d5b07 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,4 +1,5 @@ -from openai.types.chat.chat_completion import ChoiceLogprobs, ChatCompletionMessage +from openai.types.chat.chat_completion import ChoiceLogprobs +from openai.types.chat import ChatCompletionMessage from nilai_common import ( SignedChatCompletion, From 3288a36cfb782c8096b952df48c42c77190781db Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 18:01:56 +0200 Subject: [PATCH 34/75] fix: improve multimodal content handling and simplify logic in chat completion --- nilai-api/src/nilai_api/routers/private.py | 23 +++++++++---------- .../src/nilai_api/utils/content_extractor.py | 11 ++++----- .../src/nilai_common/api_model.py | 4 ++-- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index fe54250e..d8e2973c 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -214,17 +214,11 @@ async def chat_completion( ) has_multimodal = has_multimodal_content(req.messages) - if has_multimodal: - if not endpoint.metadata.multimodal_support: - raise HTTPException( - status_code=400, - detail="Model does not support multimodal content, remove image inputs from request", - ) - if req.web_search: - raise HTTPException( - status_code=400, - detail="Web search is not supported with multimodal (image) content. Use text-only input for web search.", - ) + if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): + raise HTTPException( + status_code=400, + detail="Model does not support multimodal content, remove image inputs from request", + ) model_url = endpoint.url + "/v1/" @@ -240,7 +234,12 @@ async def chat_completion( messages = req.messages sources: Optional[List[Source]] = None - if req.web_search and not has_multimodal: + if req.web_search: + if has_multimodal: + raise HTTPException( + status_code=400, + detail="Web search is not supported with multimodal (image) content. Use text-only input for web search.", + ) web_search_result = await handle_web_search(messages, model_name, client) messages = web_search_result.messages sources = web_search_result.sources diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index 19dc93af..601b3165 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -37,14 +37,11 @@ def extract_text_content( def has_multimodal_content(messages: List[Message]) -> bool: - """Check if any message contains multimodal content (image_url parts).""" - last_message = messages[-1] - last_message_content = last_message.content - is_multimodal = isinstance(last_message_content, list) and any( - isinstance(item, dict) and item.get("type") == "image_url" - for item in last_message_content + """Check if any message contains multimodal content (non-string content indicates multimodal).""" + return any( + isinstance(getattr(msg, "content", None), list) + for msg in messages ) - return is_multimodal def get_last_user_query(messages: List[Message]) -> Optional[str]: diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 98708065..60d66bb5 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -19,14 +19,14 @@ class Message(ChatCompletionMessage): content: Optional[ Union[ str, - List[ + Iterable[ Union[ ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam, ] ], ] - ] = None # type: ignore[override] + ] # type: ignore[override] class Choice(OpenaAIChoice): From d2f24c96c7043bc880a9453372e85f53b449bc9d Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 18:03:03 +0200 Subject: [PATCH 35/75] fix: ruff format --- nilai-api/src/nilai_api/utils/content_extractor.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index 601b3165..de2925bb 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -38,10 +38,7 @@ def extract_text_content( def has_multimodal_content(messages: List[Message]) -> bool: """Check if any message contains multimodal content (non-string content indicates multimodal).""" - return any( - isinstance(getattr(msg, "content", None), list) - for msg in messages - ) + return any(isinstance(getattr(msg, "content", None), list) for msg in messages) def get_last_user_query(messages: List[Message]) -> Optional[str]: From ed8f179825f173e051814017847bd1e7bca6f77c Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 18:12:24 +0200 Subject: [PATCH 36/75] fix: enhance multimodal content detection to specifically check for image_url type in messages --- nilai-api/src/nilai_api/utils/content_extractor.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index de2925bb..6a260503 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -37,8 +37,15 @@ def extract_text_content( def has_multimodal_content(messages: List[Message]) -> bool: - """Check if any message contains multimodal content (non-string content indicates multimodal).""" - return any(isinstance(getattr(msg, "content", None), list) for msg in messages) + """Check if any message contains multimodal content with image_url type.""" + for msg in messages: + content = getattr(msg, "content", None) + if hasattr(content, '__iter__') and not isinstance(content, str): + content_list = list(content) + for part in content_list: + if isinstance(part, dict) and part.get("type") == "image_url": + return True + return False def get_last_user_query(messages: List[Message]) -> Optional[str]: From af31527e5893645f0ec7c43507f700fea4ae39f2 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 18:13:59 +0200 Subject: [PATCH 37/75] fix: ruff format --- nilai-api/src/nilai_api/utils/content_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index 6a260503..b1c57166 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -40,7 +40,7 @@ def has_multimodal_content(messages: List[Message]) -> bool: """Check if any message contains multimodal content with image_url type.""" for msg in messages: content = getattr(msg, "content", None) - if hasattr(content, '__iter__') and not isinstance(content, str): + if hasattr(content, "__iter__") and not isinstance(content, str): content_list = list(content) for part in content_list: if isinstance(part, dict) and part.get("type") == "image_url": From bd462a7d9da1158c9b28b29057341d694f5f18f8 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 19:25:05 +0200 Subject: [PATCH 38/75] refactor: streamline user query extraction and enhance web search --- nilai-api/src/nilai_api/handlers/nilrag.py | 7 ++- .../src/nilai_api/handlers/web_search.py | 24 +++++++++-- nilai-api/src/nilai_api/routers/private.py | 43 +++++++++++-------- .../src/nilai_api/utils/content_extractor.py | 14 +++--- .../src/nilai_common/api_model.py | 4 +- 5 files changed, 59 insertions(+), 33 deletions(-) diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index d82ddc34..ddef19cc 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -66,10 +66,9 @@ async def handle_nilrag(req: ChatRequest): logger.debug("Extracting user query") query = None for message in req.messages: - if message.role == "user": - if message.content is not None: - query = extract_text_content(message.content) - break + if message.role == "user" and message.content is not None: + query = extract_text_content(message.content) + break if not query: raise HTTPException(status_code=400, detail="No user query found") diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index 0273f76a..d684f103 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -172,7 +172,7 @@ async def enhance_messages_with_web_search( ctx = await perform_web_search_async(query) query_source = Source(source="web_search_query", content=query) - system_content = ( + web_search_content = ( f'You have access to the following web search results for the query: "{query}"\n\n' "Use this information to provide accurate and up-to-date answers. " "Cite the sources when appropriate.\n\n" @@ -181,8 +181,26 @@ async def enhance_messages_with_web_search( "Please provide a comprehensive answer based on the search results above." ) - system_message = Message(role="system", content=system_content) - enhanced = list(messages) + [system_message] + enhanced = [] + system_message_added = False + + for msg in messages: + if msg.role == "system" and not system_message_added: + existing_content = msg.content or "" + if isinstance(existing_content, str): + combined_content = existing_content + "\n\n" + web_search_content + else: + parts = list(existing_content) + parts.append({"type": "text", "text": "\n\n" + web_search_content}) + combined_content = parts + + enhanced.append(Message(role="system", content=combined_content)) + system_message_added = True + else: + enhanced.append(msg) + + if not system_message_added: + enhanced.insert(0, Message(role="system", content=web_search_content)) return WebSearchEnhancedMessages( messages=enhanced, diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index d8e2973c..c548b79f 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -248,21 +248,24 @@ async def chat_completion( # Forwarding Streamed Responses async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: try: - response = await client.chat.completions.create( - model=req.model, - messages=messages, # type: ignore - stream=req.stream, # type: ignore - top_p=req.top_p, - temperature=req.temperature, - max_tokens=req.max_tokens, - tools=req.tools, # type: ignore - extra_body={ + request_kwargs = { + "model": req.model, + "messages": messages, # type: ignore + "stream": True, # type: ignore + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + "extra_body": { "stream_options": { "include_usage": True, "continuous_usage_stats": True, } }, - ) # type: ignore + } + if req.tools: + request_kwargs["tools"] = req.tools # type: ignore + + response = await client.chat.completions.create(**request_kwargs) # type: ignore prompt_token_usage: int = 0 completion_token_usage: int = 0 async for chunk in response: @@ -300,15 +303,17 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: chat_completion_stream_generator(), media_type="text/event-stream", # Ensure client interprets as Server-Sent Events ) - response = await client.chat.completions.create( - model=req.model, - messages=messages, # type: ignore - stream=req.stream, - top_p=req.top_p, - temperature=req.temperature, - max_tokens=req.max_tokens, - tools=req.tools, # type: ignore - ) # type: ignore + request_kwargs = { + "model": req.model, + "messages": messages, # type: ignore + "top_p": req.top_p, + "temperature": req.temperature, + "max_tokens": req.max_tokens, + } + if req.tools: + request_kwargs["tools"] = req.tools # type: ignore + + response = await client.chat.completions.create(**request_kwargs) # type: ignore model_response = SignedChatCompletion( **response.model_dump(), diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index b1c57166..fdbd3ff8 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -1,4 +1,4 @@ -from typing import Union, List, Optional +from typing import Union, List, Optional, Iterable from openai.types.chat.chat_completion_content_part_text_param import ( ChatCompletionContentPartTextParam, ) @@ -11,7 +11,7 @@ def extract_text_content( content: Union[ str, - List[ + Iterable[ Union[ ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam ] @@ -22,14 +22,14 @@ def extract_text_content( Extract text content from a message content field. Args: - content: Either a string or a list of content parts + content: Either a string or an iterable of content parts Returns: str: The extracted text content, or empty string if no text content found """ if isinstance(content, str): return content - elif isinstance(content, list): + elif hasattr(content, "__iter__") and not isinstance(content, str): for part in content: if part["type"] == "text": return part["text"] @@ -40,7 +40,11 @@ def has_multimodal_content(messages: List[Message]) -> bool: """Check if any message contains multimodal content with image_url type.""" for msg in messages: content = getattr(msg, "content", None) - if hasattr(content, "__iter__") and not isinstance(content, str): + if ( + content is not None + and hasattr(content, "__iter__") + and not isinstance(content, str) + ): content_list = list(content) for part in content_list: if isinstance(part, dict) and part.get("type") == "image_url": diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index 60d66bb5..c168e6d9 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -16,7 +16,7 @@ class Message(ChatCompletionMessage): role: Literal["system", "user", "assistant", "tool"] # type: ignore - content: Optional[ + content: Optional[ # type: ignore[reportIncompatibleVariableOverride] Union[ str, Iterable[ @@ -26,7 +26,7 @@ class Message(ChatCompletionMessage): ] ], ] - ] # type: ignore[override] + ] = None class Choice(OpenaAIChoice): From d5a11a88e1ce4b34793e4147bfcf053284f40ee6 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Wed, 27 Aug 2025 19:45:00 +0200 Subject: [PATCH 39/75] chore: remove gemma model entry from E2E config --- tests/e2e/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/config.py b/tests/e2e/config.py index 8e4bad7e..c49eff97 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -34,7 +34,6 @@ def api_key_getter(): ], "ci": [ "meta-llama/Llama-3.2-1B-Instruct", - "google/gemma-3-4b-it", ], } From 281795553212193b50b1de13b590f2f5bb72fe58 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 14:38:27 +0200 Subject: [PATCH 40/75] feat: add gemma-4b-gpu model support and update CI workflow --- .github/workflows/cicd.yml | 2 +- .../docker-compose.gemma-4b-gpu.ci.yml | 51 ++++ tests/e2e/test_openai.py | 276 ++++++++++++------ 3 files changed, 237 insertions(+), 92 deletions(-) create mode 100644 docker/compose/docker-compose.gemma-4b-gpu.ci.yml diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 5b04d4fc..6dc5c4d8 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,7 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: bash ./scripts/docker-composer.sh --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -o development-compose.yml + run: bash ./scripts/docker-composer.sh --dev -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml - name: Start Services run: | diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml new file mode 100644 index 00000000..81203986 --- /dev/null +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -0,0 +1,51 @@ +services: + gemma_4b_gpu: + image: nillion/nilai-vllm:latest + container_name: nilai-gemma_4b_gpu + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + ipc: host + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: > + --model google/gemma-3-4b-it + --gpu-memory-utilization 0.7 + --max-model-len 8192 + --max-num-batched-tokens 8192 + --tensor-parallel-size 1 + --enable-auto-tool-choice + --tool-call-parser llama3_json + --uvicorn-log-level warning + --dtype half + environment: + - SVC_HOST=gemma_4b_gpu + - SVC_PORT=8000 + - ETCD_HOST=etcd + - ETCD_PORT=2379 + - TOOL_SUPPORT=true + - MULTIMODAL_SUPPORT=true + - VLLM_USE_V1=1 + - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 + - CUDA_LAUNCH_BLOCKING=1 + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s +volumes: + hugging_face_models: diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index 76ebc322..a95f67ae 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -19,7 +19,7 @@ get_invalid_rate_limited_nuc_token, get_nildb_nuc_token, ) - +import base64, re def _create_openai_client(api_key: str) -> OpenAI: """Helper function to create an OpenAI client with SSL verification disabled""" @@ -867,132 +867,226 @@ def make_request(): ) -def test_web_search_queueing_next_second_e2e(client): - """Test that web search requests are properly queued and processed in batches.""" - import threading - import time - import openai - from concurrent.futures import ThreadPoolExecutor, as_completed - request_barrier = threading.Barrier(25) - responses = [] - start_time = None - def make_request(): - request_barrier.wait() +def test_multimodal_single_request(client): + """Test multimodal chat completion with a single request using gemma-3-4b-it model""" + if "google/gemma-3-4b-it" not in test_models: + pytest.skip("Multimodal test only runs for gemma-3-4b-it model") + + try: + # Create a simple base64 encoded image (1x1 pixel red PNG) + response = client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + temperature=0.2, + max_tokens=100, + ) - nonlocal start_time - if start_time is None: - start_time = time.time() + # Verify response structure + assert isinstance(response, ChatCompletion), ( + "Response should be a ChatCompletion object" + ) + assert response.model == "google/gemma-3-4b-it", ( + "Response model should be google/gemma-3-4b-it" + ) + assert len(response.choices) > 0, "Response should contain at least one choice" - try: - response = client.chat.completions.create( - model=test_models[0], - messages=[{"role": "user", "content": "What is the weather like?"}], - extra_body={"web_search": True}, - max_tokens=10, - temperature=0.0, - ) - completion_time = time.time() - start_time - responses.append((completion_time, response, "success")) - except openai.RateLimitError as e: - completion_time = time.time() - start_time - responses.append((completion_time, e, "rate_limited")) - except Exception as e: - completion_time = time.time() - start_time - responses.append((completion_time, e, "error")) + # Check content + content = response.choices[0].message.content + assert content is not None, "Content should not be null" + assert content.strip() != "", "Content should not be empty" - with ThreadPoolExecutor(max_workers=25) as executor: - futures = [executor.submit(make_request) for _ in range(25)] + print( + f"\nMultimodal single request response: {content[:100]}..." + if len(content) > 100 + else content + ) - for future in as_completed(futures): - try: - future.result() - except Exception as e: - print(f"Thread execution error: {e}") + assert response.usage, "No usage data returned for multimodal request" + print(f"Multimodal usage: {response.usage}") - assert len(responses) == 25, "All requests should complete" + assert response.usage.prompt_tokens > 0, ( + "No prompt tokens returned for multimodal request" + ) + assert response.usage.completion_tokens > 0, ( + "No completion tokens returned for multimodal request" + ) + assert response.usage.total_tokens > 0, ( + "No total tokens returned for multimodal request" + ) - # Categorize responses - successful_responses = [(t, r) for t, r, status in responses if status == "success"] - rate_limited_responses = [ - (t, r) for t, r, status in responses if status == "rate_limited" - ] - error_responses = [(t, r) for t, r, status in responses if status == "error"] + except Exception as e: + pytest.fail(f"Error testing multimodal single request: {str(e)}") - print( - f"Successful: {len(successful_responses)}, Rate limited: {len(rate_limited_responses)}, Errors: {len(error_responses)}" - ) - # Verify queuing behavior - # With 25 requests and 20 RPS limit, some should be queued or rate limited - assert len(rate_limited_responses) > 0 or len(successful_responses) < 25, ( - "Queuing should be enforced - either some requests should be rate limited or delayed" - ) +def test_multimodal_consecutive_requests(client): + """Test two consecutive multimodal chat completions using gemma-3-4b-it model""" + if "google/gemma-3-4b-it" not in test_models: + pytest.skip("Multimodal test only runs for gemma-3-4b-it model") + + try: + # Create a simple base64 encoded image (1x1 pixel red PNG) - for t, response in successful_responses: - assert isinstance(response, ChatCompletion), ( - "Response should be a ChatCompletion object" + # First multimodal request + response1 = client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + temperature=0.2, + max_tokens=50, ) - assert len(response.choices) > 0, "Response should contain at least one choice" - assert response.choices[0].message.content, "Response should contain content" - sources = getattr(response, "sources", None) - assert sources is not None, "Web search responses should have sources" - assert isinstance(sources, list), "Sources should be a list" - assert len(sources) > 0, "Sources should not be empty" + # Verify first response + assert isinstance(response1, ChatCompletion), ( + "First response should be a ChatCompletion object" + ) + assert response1.model == "google/gemma-3-4b-it", ( + "First response model should be google/gemma-3-4b-it" + ) + assert len(response1.choices) > 0, ( + "First response should contain at least one choice" + ) - first_source = sources[0] - assert isinstance(first_source, dict), "First source should be a dictionary" - assert "title" in first_source, "First source should have title" - assert "url" in first_source, "First source should have url" - assert "snippet" in first_source, "First source should have snippet" + content1 = response1.choices[0].message.content + assert content1 is not None, "First response content should not be null" + assert content1.strip() != "", "First response content should not be empty" - for t, error in rate_limited_responses: - assert isinstance(error, openai.RateLimitError), ( - "Rate limited responses should be RateLimitError" + print( + f"\nFirst multimodal response: {content1[:100]}..." + if len(content1) > 100 + else content1 ) - -def test_multimodal_completion(client): - """Test basic multimodal completion with image content.""" - try: - response = client.chat.completions.create( - model=test_models[0], + # Second multimodal request + response2 = client.chat.completions.create( + model="google/gemma-3-4b-it", messages=[ { "role": "user", "content": [ - {"type": "text", "text": "What is this image?"}, + {"type": "text", "text": "Describe this image in detail."}, { "type": "image_url", "image_url": { - "url": "" + "url": "" }, }, ], } ], - max_tokens=50, + temperature=0.2, + max_tokens=100, ) - assert isinstance(response, ChatCompletion), ( - "Response should be a ChatCompletion object" + # Verify second response + assert isinstance(response2, ChatCompletion), ( + "Second response should be a ChatCompletion object" + ) + assert response2.model == "google/gemma-3-4b-it", ( + "Second response model should be google/gemma-3-4b-it" + ) + assert len(response2.choices) > 0, ( + "Second response should contain at least one choice" ) - assert len(response.choices) > 0, "Response should contain at least one choice" - content = response.choices[0].message.content - assert content, "Response should contain content" + content2 = response2.choices[0].message.content + assert content2 is not None, "Second response content should not be null" + assert content2.strip() != "", "Second response content should not be empty" - print(f"\nMultimodal response: {content[:100]}...") + print( + f"\nSecond multimodal response: {content2[:100]}..." + if len(content2) > 100 + else content2 + ) - assert response.usage, "Response should contain usage data" - assert response.usage.prompt_tokens > 0, ( - "Prompt tokens should be greater than 0" + # Verify both responses have usage data + assert response1.usage, "No usage data returned for first multimodal request" + assert response2.usage, "No usage data returned for second multimodal request" + + print(f"First multimodal usage: {response1.usage}") + print(f"Second multimodal usage: {response2.usage}") + + # Verify both responses have token counts + assert response1.usage.prompt_tokens > 0, ( + "No prompt tokens returned for first multimodal request" ) - assert response.usage.completion_tokens > 0, ( - "Completion tokens should be greater than 0" + assert response1.usage.completion_tokens > 0, ( + "No completion tokens returned for first multimodal request" + ) + assert response1.usage.total_tokens > 0, ( + "No total tokens returned for first multimodal request" + ) + + assert response2.usage.prompt_tokens > 0, ( + "No prompt tokens returned for second multimodal request" + ) + assert response2.usage.completion_tokens > 0, ( + "No completion tokens returned for second multimodal request" ) + assert response2.usage.total_tokens > 0, ( + "No total tokens returned for second multimodal request" + ) + + except Exception as e: + pytest.fail(f"Error testing consecutive multimodal requests: {str(e)}") + +def test_multimodal_with_web_search_error(client): + """Test that multimodal + web search raises an error""" + if "google/gemma-3-4b-it" not in test_models: + pytest.skip("Multimodal test only runs for gemma-3-4b-it model") + + # Create a simple base64 encoded image (1x1 pixel red PNG) + + try: + client.chat.completions.create( + model="google/gemma-3-4b-it", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you see in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "" + }, + }, + ], + } + ], + extra_body={"web_search": True}, + temperature=0.2, + max_tokens=100, + ) + pytest.fail("Expected error for multimodal + web search combination") except Exception as e: - pytest.fail(f"Error testing multimodal completion: {str(e)}") + # The error should be raised, which means the test passes + print(f"Expected error received: {str(e)}") + assert "multimodal" in str(e).lower() or "400" in str(e), "Should raise multimodal or 400 error" From 3fe98f22591d9e03f74e11116011a28f91b2585b Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 14:40:36 +0200 Subject: [PATCH 41/75] refactor: remove unused import --- packages/nilai-common/src/nilai_common/__init__.py | 2 -- tests/e2e/test_openai.py | 14 +++++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 56edbf56..77a4e666 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -4,7 +4,6 @@ SignedChatCompletion, Choice, HealthCheckResponse, - Message, ModelEndpoint, ModelMetadata, Nonce, @@ -20,7 +19,6 @@ from openai.types.completion_usage import CompletionUsage as Usage __all__ = [ - "Message", "ChatRequest", "SignedChatCompletion", "Choice", diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index a95f67ae..e50c5d9e 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -19,7 +19,7 @@ get_invalid_rate_limited_nuc_token, get_nildb_nuc_token, ) -import base64, re + def _create_openai_client(api_key: str) -> OpenAI: """Helper function to create an OpenAI client with SSL verification disabled""" @@ -867,13 +867,11 @@ def make_request(): ) - - def test_multimodal_single_request(client): """Test multimodal chat completion with a single request using gemma-3-4b-it model""" if "google/gemma-3-4b-it" not in test_models: pytest.skip("Multimodal test only runs for gemma-3-4b-it model") - + try: # Create a simple base64 encoded image (1x1 pixel red PNG) response = client.chat.completions.create( @@ -937,7 +935,7 @@ def test_multimodal_consecutive_requests(client): """Test two consecutive multimodal chat completions using gemma-3-4b-it model""" if "google/gemma-3-4b-it" not in test_models: pytest.skip("Multimodal test only runs for gemma-3-4b-it model") - + try: # Create a simple base64 encoded image (1x1 pixel red PNG) @@ -1061,7 +1059,7 @@ def test_multimodal_with_web_search_error(client): """Test that multimodal + web search raises an error""" if "google/gemma-3-4b-it" not in test_models: pytest.skip("Multimodal test only runs for gemma-3-4b-it model") - + # Create a simple base64 encoded image (1x1 pixel red PNG) try: @@ -1089,4 +1087,6 @@ def test_multimodal_with_web_search_error(client): except Exception as e: # The error should be raised, which means the test passes print(f"Expected error received: {str(e)}") - assert "multimodal" in str(e).lower() or "400" in str(e), "Should raise multimodal or 400 error" + assert "multimodal" in str(e).lower() or "400" in str(e), ( + "Should raise multimodal or 400 error" + ) From 2eb25496d0e1456786ddfa8a8449483b9cca478b Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 14:42:19 +0200 Subject: [PATCH 42/75] test#1: gemma-4 test --- .github/workflows/cicd.yml | 2 +- nilai-api/src/nilai_api/handlers/nilrag.py | 21 +++-- .../src/nilai_api/handlers/web_search.py | 81 +++++++++++++---- nilai-api/src/nilai_api/routers/private.py | 64 ++++++++++--- .../src/nilai_api/utils/content_extractor.py | 91 +++++++++++-------- .../src/nilai_common/api_model.py | 33 ++----- tests/e2e/config.py | 4 +- tests/unit/nilai_api/routers/test_private.py | 1 + 8 files changed, 189 insertions(+), 108 deletions(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 6dc5c4d8..354d9abb 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,7 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: bash ./scripts/docker-composer.sh --dev -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml + run: bash ./scripts/docker-composer.sh --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml - name: Start Services run: | diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index ddef19cc..b3121071 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -3,7 +3,7 @@ import nilrag -from nilai_common import ChatRequest, Message +from nilai_common import ChatRequest from fastapi import HTTPException, status from sentence_transformers import SentenceTransformer from nilai_api.utils.content_extractor import extract_text_content @@ -66,8 +66,8 @@ async def handle_nilrag(req: ChatRequest): logger.debug("Extracting user query") query = None for message in req.messages: - if message.role == "user" and message.content is not None: - query = extract_text_content(message.content) + if message.get("role") == "user" and message.get("content") is not None: + query = extract_text_content(message.get("content")) # type: ignore break if not query: @@ -87,21 +87,22 @@ async def handle_nilrag(req: ChatRequest): # Step 4: Update system message for message in req.messages: - if message.role == "system": - if message.content is None: + if message.get("role") == "system": + content = message.get("content") + if content is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="system message is empty", ) - if isinstance(message.content, str): - message.content += relevant_context - elif isinstance(message.content, list): - message.content.append({"type": "text", "text": relevant_context}) + if isinstance(content, str): + message["content"] = content + relevant_context + elif isinstance(content, list): + content.append({"type": "text", "text": relevant_context}) break else: # If no system message exists, add one - req.messages.insert(0, Message(role="system", content=relevant_context)) + req.messages.insert(0, {"role": "system", "content": relevant_context}) logger.debug(f"System message updated with relevant context:\n {req.messages}") diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index d684f103..ac7f70b1 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -12,8 +12,8 @@ Source, WebSearchEnhancedMessages, WebSearchContext, - Message, ) +from openai.types.chat import ChatCompletionMessageParam logger = logging.getLogger(__name__) @@ -73,6 +73,14 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: } client = _get_http_client() + logger.info("Brave API request start") + logger.debug( + "Brave API params assembled q_len=%d country=%s lang=%s count=%s", + len(q), + params.get("country"), + params.get("lang"), + params.get("count"), + ) resp = await client.get( WEB_SEARCH_SETTINGS.api_path, headers=headers, params=params ) @@ -85,7 +93,13 @@ async def _make_brave_api_request(query: str) -> Dict[str, Any]: ) try: - return resp.json() + data = resp.json() + logger.info("Brave API request success") + logger.debug( + "Brave API response keys=%s", + list(data.keys()) if isinstance(data, dict) else type(data).__name__, + ) + return data except Exception: logger.exception("Failed to parse Brave API JSON") raise HTTPException( @@ -116,6 +130,7 @@ def _parse_brave_results(data: Dict[str, Any]) -> List[SearchResult]: if title and body and url: results.append(SearchResult(title=title, body=body, url=url)) + logger.debug("Parsed brave results count=%d", len(results)) return results @@ -137,6 +152,8 @@ async def perform_web_search_async(query: str) -> WebSearchContext: detail="Web search requested with an empty query", ) + logger.info("Web search start") + logger.debug("Web search raw query len=%d", len(query)) data = await _make_brave_api_request(query) results = _parse_brave_results(data) @@ -146,6 +163,7 @@ async def perform_web_search_async(query: str) -> WebSearchContext: detail="No web results found", ) + logger.info("Web search results ready count=%d", len(results)) lines = [ f"[{idx}] {r.title}\nURL: {r.url}\nSnippet: {r.body}" for idx, r in enumerate(results, start=1) @@ -156,8 +174,15 @@ async def perform_web_search_async(query: str) -> WebSearchContext: return WebSearchContext(prompt=prompt, sources=sources) +def _get_role_and_content(msg): + if isinstance(msg, dict): + return msg.get("role"), msg.get("content") + # If some SDK returns an object + return getattr(msg, "role", None), getattr(msg, "content", None) + + async def enhance_messages_with_web_search( - messages: List[Message], query: str + messages: List[ChatCompletionMessageParam], query: str ) -> WebSearchEnhancedMessages: """Enhance a list of messages with web search context.Collapse commentComment on line L155jcabrero commented on Aug 26, 2025 jcabreroon Aug 26, 2025MemberDeleted docstring?Write a replyResolve commentCode has comments. Press enter to view. @@ -181,26 +206,31 @@ async def enhance_messages_with_web_search( "Please provide a comprehensive answer based on the search results above." ) - enhanced = [] + enhanced: List[ChatCompletionMessageParam] = [] system_message_added = False for msg in messages: - if msg.role == "system" and not system_message_added: - existing_content = msg.content or "" - if isinstance(existing_content, str): - combined_content = existing_content + "\n\n" + web_search_content - else: - parts = list(existing_content) + role, content = _get_role_and_content(msg) + + if role == "system" and not system_message_added: + if isinstance(content, str): + combined_content = content + "\n\n" + web_search_content + elif isinstance(content, list): + # content is likely a list of parts (for multimodal); append a text part + parts = list(content) parts.append({"type": "text", "text": "\n\n" + web_search_content}) combined_content = parts + else: + combined_content = web_search_content - enhanced.append(Message(role="system", content=combined_content)) + enhanced.append({"role": "system", "content": combined_content}) # type: ignore system_message_added = True else: - enhanced.append(msg) + # Re-append in dict form + enhanced.append({"role": role or "user", "content": content}) # type: ignore if not system_message_added: - enhanced.insert(0, Message(role="system", content=web_search_content)) + enhanced.insert(0, {"role": "system", "content": web_search_content}) return WebSearchEnhancedMessages( messages=enhanced, @@ -224,19 +254,24 @@ async def generate_search_query_from_llm( ) messages = [ - Message(role="system", content=system_prompt), - Message(role="user", content=user_message), + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message}, ] req = { "model": model_name, - "messages": [m.model_dump() for m in messages], + "messages": messages, "max_tokens": 150, } + logger.info("Generate search query start model=%s", model_name) + logger.debug( + "User message len=%d", len(user_message) if isinstance(user_message, str) else 0 + ) try: response = await client.chat.completions.create(**req) except Exception as exc: + logger.exception("LLM call failed") raise RuntimeError(f"Failed to generate search query: {exc}") from exc try: @@ -244,16 +279,20 @@ async def generate_search_query_from_llm( msg = choices[0].message content = (getattr(msg, "content", None) or "").strip() except Exception as exc: + logger.exception("Invalid LLM response structure") raise RuntimeError(f"Invalid response structure from LLM: {exc}") from exc if not content: + logger.error("LLM returned empty search query") raise RuntimeError("LLM returned an empty search query") + logger.info("Generate search query success") + logger.debug("Generated query len=%d", len(content)) return content async def handle_web_search( - req_messages: List[Message], model_name: str, client + req_messages: List[ChatCompletionMessageParam], model_name: str, client ) -> WebSearchEnhancedMessages: """Handle web search enhancement for a conversation. @@ -269,18 +308,26 @@ async def handle_web_search( WebSearchEnhancedMessages with web search context added, or original messages if no user query is found or search fails """ + logger.info("Handle web search start") + logger.debug( + "Handle web search messages_in=%d model=%s", len(req_messages), model_name + ) user_query = get_last_user_query(req_messages) if not user_query: + logger.info("No user query found") return WebSearchEnhancedMessages(messages=req_messages, sources=[]) try: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) + logger.info("Enhancing messages with web search context") return await enhance_messages_with_web_search(req_messages, concise_query) except HTTPException: + logger.exception("Web search provider error") return WebSearchEnhancedMessages(messages=req_messages, sources=[]) except Exception: + logger.exception("Unexpected error during web search handling") return WebSearchEnhancedMessages(messages=req_messages, sources=[]) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index c548b79f..6c303206 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -1,6 +1,8 @@ # Fast API and serving import asyncio import logging +import time +import uuid from base64 import b64encode from typing import AsyncGenerator, Optional, Union, List, Tuple from nilai_api.attestation import get_attestation_report @@ -22,15 +24,21 @@ from nilai_common import ( AttestationReport, ChatRequest, - Message, ModelMetadata, SignedChatCompletion, Nonce, Source, Usage, ) + +from openai.types.chat import ( + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, +) + from openai import AsyncOpenAI + logger = logging.getLogger(__name__) router = APIRouter() @@ -133,8 +141,8 @@ async def chat_completion( ChatRequest( model="meta-llama/Llama-3.2-1B-Instruct", messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is your name?"), + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is your name?"}, ], ) ), @@ -200,6 +208,9 @@ async def chat_completion( """ model_name = req.model + request_id = str(uuid.uuid4()) + t_start = time.monotonic() + logger.info(f"[chat] call start request_id={req.messages}") endpoint = await state.get_model(model_name) if endpoint is None: raise HTTPException( @@ -214,6 +225,7 @@ async def chat_completion( ) has_multimodal = has_multimodal_content(req.messages) + logger.info(f"[chat] has_multimodal: {has_multimodal}") if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): raise HTTPException( status_code=400, @@ -223,34 +235,43 @@ async def chat_completion( model_url = endpoint.url + "/v1/" logger.info( - f"Chat completion request for model {model_name} from user {auth_info.user.userid} on url: {model_url}" + f"[chat] start request_id={request_id} user={auth_info.user.userid} model={model_name} stream={req.stream} web_search={bool(req.web_search)} tools={bool(req.tools)} multimodal={has_multimodal} url={model_url}" ) client = AsyncOpenAI(base_url=model_url, api_key="") if req.nilrag: + logger.info(f"[chat] nilrag start request_id={request_id}") + t_nilrag = time.monotonic() await handle_nilrag(req) + logger.info( + f"[chat] nilrag done request_id={request_id} duration_ms={(time.monotonic() - t_nilrag) * 1000:.0f}" + ) messages = req.messages sources: Optional[List[Source]] = None if req.web_search: - if has_multimodal: - raise HTTPException( - status_code=400, - detail="Web search is not supported with multimodal (image) content. Use text-only input for web search.", - ) + logger.info(f"[chat] web_search start request_id={request_id}") + t_ws = time.monotonic() web_search_result = await handle_web_search(messages, model_name, client) messages = web_search_result.messages sources = web_search_result.sources + logger.info( + f"[chat] web_search done request_id={request_id} sources={len(sources) if sources else 0} duration_ms={(time.monotonic() - t_ws) * 1000:.0f}" + ) + logger.info(f"[chat] web_search messages: {messages}") if req.stream: # Forwarding Streamed Responses async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: try: + logger.info(f"[chat] stream start request_id={request_id}") + t_call = time.monotonic() + current_messages = messages request_kwargs = { "model": req.model, - "messages": messages, # type: ignore + "messages": current_messages, # type: ignore "stream": True, # type: ignore "top_p": req.top_p, "temperature": req.temperature, @@ -293,9 +314,12 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: prompt_tokens=prompt_token_usage, completion_tokens=completion_token_usage, ) + logger.info( + f"[chat] stream done request_id={request_id} prompt_tokens={prompt_token_usage} completion_tokens={completion_token_usage} duration_ms={(time.monotonic() - t_call) * 1000:.0f} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) except Exception as e: - logger.error(f"Error streaming response: {e}") + logger.error(f"[chat] stream error request_id={request_id} error={e}") return # Return the streaming response @@ -303,23 +327,32 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: chat_completion_stream_generator(), media_type="text/event-stream", # Ensure client interprets as Server-Sent Events ) + current_messages = messages request_kwargs = { "model": req.model, - "messages": messages, # type: ignore + "messages": current_messages, # type: ignore "top_p": req.top_p, "temperature": req.temperature, "max_tokens": req.max_tokens, } if req.tools: request_kwargs["tools"] = req.tools # type: ignore - + logger.info(f"[chat] call start request_id={request_id}") + logger.info(f"[chat] call message: {current_messages}") + t_call = time.monotonic() response = await client.chat.completions.create(**request_kwargs) # type: ignore - + logger.info( + f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) + logger.info(f"[chat] call response: {response}") model_response = SignedChatCompletion( **response.model_dump(), signature="", sources=sources, ) + logger.info( + f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}" + ) if model_response.usage is None: raise HTTPException( @@ -345,4 +378,7 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]: signature = sign_message(state.private_key, response_json) model_response.signature = b64encode(signature).decode() + logger.info( + f"[chat] done request_id={request_id} prompt_tokens={model_response.usage.prompt_tokens} completion_tokens={model_response.usage.completion_tokens} total_ms={(time.monotonic() - t_start) * 1000:.0f}" + ) return model_response diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index fdbd3ff8..d7dbf505 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -1,11 +1,39 @@ -from typing import Union, List, Optional, Iterable +from typing import Union, List, Optional, Iterable, cast, Any from openai.types.chat.chat_completion_content_part_text_param import ( ChatCompletionContentPartTextParam, ) from openai.types.chat.chat_completion_content_part_image_param import ( ChatCompletionContentPartImageParam, ) -from nilai_common import Message +from openai.types.chat import ChatCompletionMessageParam + + +def _iter_parts(content): + if content is None or isinstance(content, str): + return () + if isinstance(content, dict): + return (content,) + try: + iter(content) + return content + except TypeError: + return () + + +def _is_multimodal_part(part) -> bool: + if isinstance(part, dict): + if "image_url" in part or "input_audio" in part or "file" in part: + return True + t = part.get("type") + return t in {"image_url", "input_audio", "file"} + if ( + hasattr(part, "image_url") + or hasattr(part, "input_audio") + or hasattr(part, "file") + ): + return True + t = getattr(part, "type", None) + return t in {"image_url", "input_audio", "file"} def extract_text_content( @@ -13,56 +41,45 @@ def extract_text_content( str, Iterable[ Union[ - ChatCompletionContentPartTextParam, ChatCompletionContentPartImageParam + ChatCompletionContentPartTextParam, + ChatCompletionContentPartImageParam, + dict, ] ], + None, ], ) -> str: - """ - Extract text content from a message content field. - - Args: - content: Either a string or an iterable of content parts - - Returns: - str: The extracted text content, or empty string if no text content found - """ if isinstance(content, str): return content - elif hasattr(content, "__iter__") and not isinstance(content, str): - for part in content: - if part["type"] == "text": - return part["text"] + for part in _iter_parts(content): + if isinstance(part, dict): + if part.get("type") == "text": # type: ignore + txt = part.get("text") # type: ignore + if isinstance(txt, str): + return txt + else: + if getattr(part, "type", None) == "text": + txt = getattr(part, "text", None) + if isinstance(txt, str): + return txt return "" -def has_multimodal_content(messages: List[Message]) -> bool: - """Check if any message contains multimodal content with image_url type.""" - for msg in messages: - content = getattr(msg, "content", None) - if ( - content is not None - and hasattr(content, "__iter__") - and not isinstance(content, str) - ): - content_list = list(content) - for part in content_list: - if isinstance(part, dict) and part.get("type") == "image_url": - return True +def has_multimodal_content(messages: List[ChatCompletionMessageParam]) -> bool: + for m in messages: + for part in _iter_parts(m.get("content")): + if _is_multimodal_part(part): + return True return False -def get_last_user_query(messages: List[Message]) -> Optional[str]: - """ - Walk from the end to find the most recent user-authored content, and - extract text from either a string or a multimodal content list. - """ +def get_last_user_query(messages: List[ChatCompletionMessageParam]) -> Optional[str]: for msg in reversed(messages): - if getattr(msg, "role", None) == "user": - content = getattr(msg, "content", None) + if msg.get("role") == "user": + content = msg.get("content") if content is not None: try: - text = extract_text_content(content) + text = extract_text_content(content) # type: ignore except Exception: text = None if text: diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index c168e6d9..ca6307af 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -1,32 +1,11 @@ import uuid -from typing import Annotated, Iterable, List, Literal, Optional, Union +from typing import Annotated, Iterable, List, Optional -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat import ChatCompletion, ChatCompletionMessageParam from openai.types.chat import ChatCompletionToolParam from openai.types.chat.chat_completion import Choice as OpenaAIChoice -from openai.types.chat.chat_completion_content_part_image_param import ( - ChatCompletionContentPartImageParam, -) -from openai.types.chat.chat_completion_content_part_text_param import ( - ChatCompletionContentPartTextParam, -) -from pydantic import BaseModel, Field - - -class Message(ChatCompletionMessage): - role: Literal["system", "user", "assistant", "tool"] # type: ignore - content: Optional[ # type: ignore[reportIncompatibleVariableOverride] - Union[ - str, - Iterable[ - Union[ - ChatCompletionContentPartTextParam, - ChatCompletionContentPartImageParam, - ] - ], - ] - ] = None +from pydantic import BaseModel, Field, SkipValidation class Choice(OpenaAIChoice): @@ -45,7 +24,7 @@ class SearchResult(BaseModel): class WebSearchEnhancedMessages(BaseModel): - messages: List[Message] + messages: List[ChatCompletionMessageParam] sources: List[Source] @@ -58,7 +37,9 @@ class WebSearchContext(BaseModel): class ChatRequest(BaseModel): model: str - messages: List[Message] = Field(..., min_length=1) + messages: SkipValidation[List[ChatCompletionMessageParam]] = Field( + ..., min_length=1 + ) temperature: Optional[float] = Field(default=0.2, ge=0.0, le=5.0) top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0) max_tokens: Optional[int] = Field(default=2048, ge=1, le=100000) diff --git a/tests/e2e/config.py b/tests/e2e/config.py index c49eff97..3d67aa54 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -32,9 +32,7 @@ def api_key_getter(): "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", ], - "ci": [ - "meta-llama/Llama-3.2-1B-Instruct", - ], + "ci": ["meta-llama/Llama-3.2-1B-Instruct", "google/gemma-3-4b-it"], } if ENVIRONMENT not in models: diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index a0b7f486..f7e6ea86 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -258,6 +258,7 @@ def test_chat_completion_image_web_search_error( }, headers={"Authorization": "Bearer test-api-key"}, ) + print(response) assert response.status_code == 400 From d6979a1bba16d339558c31010cbdd11f710e0ce1 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 14:44:54 +0200 Subject: [PATCH 43/75] fix: ci yml --- .github/workflows/cicd.yml | 96 +++++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 354d9abb..a2c6aafd 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,99 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: bash ./scripts/docker-composer.sh --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml + run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-cpu.ci.yml development-compose.yml + + - name: GPU stack versions (non-fatal) + shell: bash + run: | + set +e # never fail this step + + echo "::group::Host & kernel" + uname -a || true + echo "Kernel: $(uname -r 2>/dev/null || echo unknown)" + test -e /var/run/reboot-required && echo "Reboot flag: PRESENT" || echo "Reboot flag: none" + echo "::endgroup::" + + echo "::group::NVIDIA driver" + if command -v nvidia-smi >/dev/null 2>&1; then + nvidia-smi || true + echo "Driver version (nvidia-smi): $(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null | head -n1 || echo unknown)" + echo "GPU(s):"; nvidia-smi -L || true + else + echo "nvidia-smi: not found" + fi + if [ -r /proc/driver/nvidia/version ]; then + echo "--- /proc/driver/nvidia/version ---" + cat /proc/driver/nvidia/version || true + else + echo "/proc/driver/nvidia/version: not present" + fi + command -v modinfo >/dev/null 2>&1 && { echo "--- modinfo nvidia (head) ---"; modinfo nvidia 2>/dev/null | head -n 20 || true; } || true + echo "::endgroup::" + + echo "::group::DKMS status" + command -v dkms >/dev/null 2>&1 && dkms status | grep -i nvidia || echo "dkms or nvidia dkms info not present" + echo "::endgroup::" + + echo "::group::CUDA toolkit/runtime" + if command -v nvcc >/dev/null 2>&1; then + nvcc --version || true + else + echo "nvcc: not found" + fi + echo "libcudart in ldconfig:" + ldconfig -p 2>/dev/null | grep -i libcudart || echo "libcudart not found in ldconfig cache" + echo "NCCL packages:" + dpkg -l 2>/dev/null | grep -iE '^ii\s+libnccl' || echo "NCCL not installed (Debian/Ubuntu dpkg check)" + echo "::endgroup::" + + echo "::group::Container stack" + docker --version || echo "docker: not found" + docker info 2>/dev/null | grep -iE 'Runtimes|nvidia' || echo "docker info: no nvidia runtime line found" + containerd --version 2>/dev/null || echo "containerd: not found" + runc --version 2>/dev/null || echo "runc: not found" + echo "::endgroup::" + + echo "::group::NVIDIA container runtime/toolkit" + # Legacy/runtime binaries + if command -v nvidia-container-runtime >/dev/null 2>&1; then + nvidia-container-runtime --version || nvidia-container-runtime -v || true + else + echo "nvidia-container-runtime: not found" + fi + # Toolkit binaries (newer distros) + if command -v nvidia-ctk >/dev/null 2>&1; then + nvidia-ctk --version || true + nvidia-ctk runtime configure --help >/dev/null 2>&1 || true + else + echo "nvidia-ctk: not found" + fi + if command -v nvidia-container-toolkit >/dev/null 2>&1; then + nvidia-container-toolkit --version || true + else + echo "nvidia-container-toolkit: not found" + fi + echo "libnvidia-container packages:" + dpkg -l 2>/dev/null | grep -iE '^ii\s+(libnvidia-container1|libnvidia-container-tools)\s' || echo "libnvidia-container packages not found (dpkg)" + # Show runtime config if present + if [ -f /etc/nvidia-container-runtime/config.toml ]; then + echo "--- /etc/nvidia-container-runtime/config.toml (head) ---" + sed -n '1,120p' /etc/nvidia-container-runtime/config.toml || true + else + echo "/etc/nvidia-container-runtime/config.toml: not present" + fi + echo "::endgroup::" + + echo "::group::Apt logs (NVIDIA-related entries)" + for f in /var/log/apt/history.log /var/log/apt/term.log /var/log/unattended-upgrades/unattended-upgrades.log; do + if [[ -f "$f" ]]; then + echo "--- scanning $f" + grep -H -i -E 'nvidia|cuda|container-toolkit' "$f" || echo "no recent NVIDIA entries" + else + echo "missing: $f" + fi + done + echo "::endgroup::" - name: Start Services run: | @@ -230,4 +322,4 @@ jobs: mode: stop github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} label: ${{ needs.start-runner.outputs.label }} - ec2-instances-ids: ${{ needs.start-runner.outputs.ec2-instances-ids }} + ec2-instances-ids: ${{ needs.start-runner.outputs.ec2-instances-ids }} \ No newline at end of file From 9d6010389a9bd2c7c448fd7e9124203be2573c87 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 14:53:05 +0200 Subject: [PATCH 44/75] fix: ruff check --- nilai-api/src/nilai_api/routers/private.py | 4 ---- nilai-api/src/nilai_api/utils/content_extractor.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 6c303206..2381da08 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -31,10 +31,6 @@ Usage, ) -from openai.types.chat import ( - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, -) from openai import AsyncOpenAI diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py index d7dbf505..2d8b2092 100644 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ b/nilai-api/src/nilai_api/utils/content_extractor.py @@ -1,4 +1,4 @@ -from typing import Union, List, Optional, Iterable, cast, Any +from typing import Union, List, Optional, Iterable from openai.types.chat.chat_completion_content_part_text_param import ( ChatCompletionContentPartTextParam, ) From ed4735d9a5064b6ad6950d926a651d03bdda8503 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 15:16:53 +0200 Subject: [PATCH 45/75] fix: ci flag --- .github/workflows/cicd.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index a2c6aafd..acb30ae7 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,7 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-cpu.ci.yml development-compose.yml + run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-cpu.ci.yml -o development-compose.yml - name: GPU stack versions (non-fatal) shell: bash From 37e5709b5aad9c04eab1408a0eb2fdca91194b25 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 15:29:20 +0200 Subject: [PATCH 46/75] fix: ci model --- .github/workflows/cicd.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index acb30ae7..f4c02b0a 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,7 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-cpu.ci.yml -o development-compose.yml + run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml - name: GPU stack versions (non-fatal) shell: bash From 023dec00f76b0bac6c06dcecc06e901ccecacd66 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 15:52:13 +0200 Subject: [PATCH 47/75] fix: ci gemma configuration --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 81203986..2934a862 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -7,9 +7,9 @@ services: reservations: devices: - driver: nvidia - count: all + count: 1 capabilities: [gpu] - ipc: host + ulimits: memlock: -1 stack: 67108864 @@ -21,14 +21,13 @@ services: condition: service_healthy command: > --model google/gemma-3-4b-it - --gpu-memory-utilization 0.7 - --max-model-len 8192 - --max-num-batched-tokens 8192 + --gpu-memory-utilization 0.5 + --max-model-len 30000 + --max-num-batched-tokens 30000 --tensor-parallel-size 1 --enable-auto-tool-choice --tool-call-parser llama3_json --uvicorn-log-level warning - --dtype half environment: - SVC_HOST=gemma_4b_gpu - SVC_PORT=8000 @@ -36,8 +35,6 @@ services: - ETCD_PORT=2379 - TOOL_SUPPORT=true - MULTIMODAL_SUPPORT=true - - VLLM_USE_V1=1 - - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 - CUDA_LAUNCH_BLOCKING=1 volumes: - hugging_face_models:/root/.cache/huggingface From da0602f778160ca94825da2211b480bedebdcd5a Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 16:19:10 +0200 Subject: [PATCH 48/75] test#2: remove llama-1b --- .github/workflows/cicd.yml | 2 +- tests/e2e/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index f4c02b0a..9631f18d 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,7 +137,7 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-gpu.ci.yml -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml + run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml - name: GPU stack versions (non-fatal) shell: bash diff --git a/tests/e2e/config.py b/tests/e2e/config.py index 3d67aa54..7caae4a4 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -32,7 +32,7 @@ def api_key_getter(): "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", ], - "ci": ["meta-llama/Llama-3.2-1B-Instruct", "google/gemma-3-4b-it"], + "ci": ["google/gemma-3-4b-it"], } if ENVIRONMENT not in models: From fccbd1d4457f4f6c2aec49a73ec3627149dc3de2 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 16:38:43 +0200 Subject: [PATCH 49/75] fix: update the script for gemma --- scripts/wait_for_ci_services.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index 163fc50c..6194d8b4 100755 --- a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -2,7 +2,7 @@ # Wait for the services to be ready API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) -MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-llama_1b_gpu 2>/dev/null) +MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-gemma_4b_gpu 2>/dev/null) NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) MAX_ATTEMPTS=30 ATTEMPT=1 From 03ca9eb9c847746b3b12a10927fa3b2cb5196d1b Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 16:39:07 +0200 Subject: [PATCH 50/75] fix#2 --- scripts/wait_for_ci_services.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index 6194d8b4..bd35621f 100755 --- a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -11,7 +11,7 @@ while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do echo "Waiting for nilai to become healthy... API:[$API_HEALTH_STATUS] MODEL:[$MODEL_HEALTH_STATUS] NUC_API:[$NUC_API_HEALTH_STATUS] (Attempt $ATTEMPT/$MAX_ATTEMPTS)" sleep 30 API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) - MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-llama_1b_gpu 2>/dev/null) + MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-gemma_4b_gpu 2>/dev/null) NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) if [ "$API_HEALTH_STATUS" = "healthy" ] && [ "$MODEL_HEALTH_STATUS" = "healthy" ] && [ "$NUC_API_HEALTH_STATUS" = "healthy" ]; then break From 359f51861611bb81f5568450a08496575307d394 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 16:57:32 +0200 Subject: [PATCH 51/75] fix: add service startup logs --- scripts/wait_for_ci_services.sh | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index bd35621f..829c0fc0 100755 --- a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -9,6 +9,14 @@ ATTEMPT=1 while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do echo "Waiting for nilai to become healthy... API:[$API_HEALTH_STATUS] MODEL:[$MODEL_HEALTH_STATUS] NUC_API:[$NUC_API_HEALTH_STATUS] (Attempt $ATTEMPT/$MAX_ATTEMPTS)" + + # Check if any service is unhealthy and print logs + if [ "$API_HEALTH_STATUS" = "unhealthy" ]; then + echo "=== nilai-api is unhealthy, printing logs ===" + docker logs nilai-api --tail 50 + fi + + sleep 30 API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-gemma_4b_gpu 2>/dev/null) From 967eace7fb243eb3e359fa9308b0d4d9d7898e14 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 17:13:17 +0200 Subject: [PATCH 52/75] fix: update gemma ci config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 2934a862..4e0eb230 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -21,9 +21,10 @@ services: condition: service_healthy command: > --model google/gemma-3-4b-it - --gpu-memory-utilization 0.5 - --max-model-len 30000 - --max-num-batched-tokens 30000 + --model-impl vllm + --disable-sliding-window + --max-model-len 8192 + --max-num-batched-tokens 8192 --tensor-parallel-size 1 --enable-auto-tool-choice --tool-call-parser llama3_json @@ -36,6 +37,8 @@ services: - TOOL_SUPPORT=true - MULTIMODAL_SUPPORT=true - CUDA_LAUNCH_BLOCKING=1 + - VLLM_USE_V1=1 + - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 volumes: - hugging_face_models:/root/.cache/huggingface healthcheck: From 8b5a0733f79df0175ab844f43f2a4c53a9dfcbc0 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 17:33:37 +0200 Subject: [PATCH 53/75] fix: added logs for services --- scripts/wait_for_ci_services.sh | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index 829c0fc0..3725259f 100755 --- a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -15,8 +15,17 @@ while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do echo "=== nilai-api is unhealthy, printing logs ===" docker logs nilai-api --tail 50 fi - + if [ "$MODEL_HEALTH_STATUS" = "unhealthy" ]; then + echo "=== nilai-gemma_4b_gpu is unhealthy, printing logs ===" + docker logs nilai-gemma_4b_gpu --tail 50 + fi + + if [ "$NUC_API_HEALTH_STATUS" = "unhealthy" ]; then + echo "=== nilai-nuc-api is unhealthy, printing logs ===" + docker logs nilai-nuc-api --tail 50 + fi + sleep 30 API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-gemma_4b_gpu 2>/dev/null) @@ -31,17 +40,23 @@ done echo "API_HEALTH_STATUS: $API_HEALTH_STATUS" if [ "$API_HEALTH_STATUS" != "healthy" ]; then echo "Error: nilai-api failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-api ===" + docker logs nilai-api --tail 100 exit 1 fi echo "MODEL_HEALTH_STATUS: $MODEL_HEALTH_STATUS" if [ "$MODEL_HEALTH_STATUS" != "healthy" ]; then - echo "Error: nilai-llama_1b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" + echo "Error: nilai-gemma_4b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-gemma_4b_gpu ===" + docker logs nilai-gemma_4b_gpu --tail 100 exit 1 fi echo "NUC_API_HEALTH_STATUS: $NUC_API_HEALTH_STATUS" if [ "$NUC_API_HEALTH_STATUS" != "healthy" ]; then echo "Error: nilai-nuc-api failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-nuc-api ===" + docker logs nilai-nuc-api --tail 100 exit 1 fi From e9cc0da151dd234d8764e4810d7b8e5ef18d3306 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 17:52:45 +0200 Subject: [PATCH 54/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 4e0eb230..7f40997e 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -37,7 +37,6 @@ services: - TOOL_SUPPORT=true - MULTIMODAL_SUPPORT=true - CUDA_LAUNCH_BLOCKING=1 - - VLLM_USE_V1=1 - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 volumes: - hugging_face_models:/root/.cache/huggingface From 3dcd4e5cdea330fa75c9e6cb9a9973352fe7bcef Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 18:10:34 +0200 Subject: [PATCH 55/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 7f40997e..41e8f883 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -22,12 +22,14 @@ services: command: > --model google/gemma-3-4b-it --model-impl vllm - --disable-sliding-window - --max-model-len 8192 - --max-num-batched-tokens 8192 + --dtype float16 --tensor-parallel-size 1 - --enable-auto-tool-choice - --tool-call-parser llama3_json + --max-model-len 1024 + --sliding-window 1024 + --max-num-batched-tokens 1024 + --max-num-seqs 8 + --gpu-memory-utilization 0.82 + --swap-space 8 --uvicorn-log-level warning environment: - SVC_HOST=gemma_4b_gpu @@ -38,6 +40,7 @@ services: - MULTIMODAL_SUPPORT=true - CUDA_LAUNCH_BLOCKING=1 - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 + - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True volumes: - hugging_face_models:/root/.cache/huggingface healthcheck: From 86c8e0abba4ed710d1502f391cedefd0d60ef3b4 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 19:17:42 +0200 Subject: [PATCH 56/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 41e8f883..7aca8c53 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -25,7 +25,6 @@ services: --dtype float16 --tensor-parallel-size 1 --max-model-len 1024 - --sliding-window 1024 --max-num-batched-tokens 1024 --max-num-seqs 8 --gpu-memory-utilization 0.82 From cfc2e074cf749be23255f390475322acae32539a Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 19:48:13 +0200 Subject: [PATCH 57/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 7aca8c53..8b367d78 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -30,6 +30,7 @@ services: --gpu-memory-utilization 0.82 --swap-space 8 --uvicorn-log-level warning + --dtype bfloat16 environment: - SVC_HOST=gemma_4b_gpu - SVC_PORT=8000 From 96f78af6966c1f3b720be294b4aa0bdd7f707fce Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 20:22:57 +0200 Subject: [PATCH 58/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 8b367d78..58b5a6f1 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -30,7 +30,7 @@ services: --gpu-memory-utilization 0.82 --swap-space 8 --uvicorn-log-level warning - --dtype bfloat16 + --dtype=half environment: - SVC_HOST=gemma_4b_gpu - SVC_PORT=8000 From f1c7b4d6eda5502381237d4beddf430a054ce80d Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Fri, 29 Aug 2025 21:26:06 +0200 Subject: [PATCH 59/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 58b5a6f1..f0f023ea 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -22,7 +22,6 @@ services: command: > --model google/gemma-3-4b-it --model-impl vllm - --dtype float16 --tensor-parallel-size 1 --max-model-len 1024 --max-num-batched-tokens 1024 From f4451ca4d3578d840830c9543b33a6270badb2b2 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 09:15:51 +0200 Subject: [PATCH 60/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index f0f023ea..11733c00 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -29,7 +29,6 @@ services: --gpu-memory-utilization 0.82 --swap-space 8 --uvicorn-log-level warning - --dtype=half environment: - SVC_HOST=gemma_4b_gpu - SVC_PORT=8000 From 25bea107dd35b3ea70b8d109e1226dbdbdcbdc76 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 09:37:43 +0200 Subject: [PATCH 61/75] fix: update gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 11733c00..f0d04400 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -24,9 +24,9 @@ services: --model-impl vllm --tensor-parallel-size 1 --max-model-len 1024 - --max-num-batched-tokens 1024 - --max-num-seqs 8 - --gpu-memory-utilization 0.82 + --max-num-batched-tokens 256 + --max-num-seqs 4 + --gpu-memory-utilization 0.70 --swap-space 8 --uvicorn-log-level warning environment: From bd8ba991149a4b1c5a10f6f8aaef0da53e5c71e8 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 09:55:45 +0200 Subject: [PATCH 62/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index f0d04400..6b18eccb 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -23,9 +23,9 @@ services: --model google/gemma-3-4b-it --model-impl vllm --tensor-parallel-size 1 - --max-model-len 1024 - --max-num-batched-tokens 256 - --max-num-seqs 4 + --max-model-len 512 + --max-num-batched-tokens 512 + --max-num-seqs 1 --gpu-memory-utilization 0.70 --swap-space 8 --uvicorn-log-level warning From eb4dbe0a9adb451d95a4ac4825acaa7323f552d5 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 10:15:21 +0200 Subject: [PATCH 63/75] fix: gemma config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 6b18eccb..c3c3940f 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -29,6 +29,7 @@ services: --gpu-memory-utilization 0.70 --swap-space 8 --uvicorn-log-level warning + --quantization bitsandbytes environment: - SVC_HOST=gemma_4b_gpu - SVC_PORT=8000 From eb3f3de91db20a5e46a750084de8981927b3afcc Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 10:29:59 +0200 Subject: [PATCH 64/75] fix: use qwen-2b instead of gemma-4b for ci pipeline --- .../docker-compose.gemma-4b-gpu.ci.yml | 25 +++++++++++-------- scripts/wait_for_ci_services.sh | 14 +++++------ tests/e2e/config.py | 2 +- tests/e2e/test_openai.py | 22 ++++++++-------- tests/unit/nilai_api/routers/test_private.py | 4 +-- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index c3c3940f..bb2f71b4 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -1,7 +1,9 @@ +version: "3.8" + services: - gemma_4b_gpu: + qwen2vl_2b_gpu: image: nillion/nilai-vllm:latest - container_name: nilai-gemma_4b_gpu + container_name: nilai-qwen2vl_2b_gpu deploy: resources: reservations: @@ -9,7 +11,6 @@ services: - driver: nvidia count: 1 capabilities: [gpu] - ulimits: memlock: -1 stack: 67108864 @@ -20,23 +21,24 @@ services: etcd: condition: service_healthy command: > - --model google/gemma-3-4b-it + --model Qwen/Qwen2-VL-2B-Instruct-AWQ --model-impl vllm --tensor-parallel-size 1 - --max-model-len 512 - --max-num-batched-tokens 512 - --max-num-seqs 1 - --gpu-memory-utilization 0.70 + --trust-remote-code + --max-model-len 768 + --max-num-batched-tokens 768 + --max-num-seqs 2 + --gpu-memory-utilization 0.80 --swap-space 8 --uvicorn-log-level warning - --quantization bitsandbytes + --quantization awq environment: - - SVC_HOST=gemma_4b_gpu + - SVC_HOST=qwen2vl_2b_gpu - SVC_PORT=8000 - ETCD_HOST=etcd - ETCD_PORT=2379 - TOOL_SUPPORT=true - - MULTIMODAL_SUPPORT=true + - MULTIMODAL_SUPPORT=true - CUDA_LAUNCH_BLOCKING=1 - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True @@ -48,5 +50,6 @@ services: retries: 3 start_period: 60s timeout: 10s + volumes: hugging_face_models: diff --git a/scripts/wait_for_ci_services.sh b/scripts/wait_for_ci_services.sh index 3725259f..36b2a75e 100755 --- a/scripts/wait_for_ci_services.sh +++ b/scripts/wait_for_ci_services.sh @@ -2,7 +2,7 @@ # Wait for the services to be ready API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) -MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-gemma_4b_gpu 2>/dev/null) +MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-qwen2vl_2b_gpu 2>/dev/null) NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) MAX_ATTEMPTS=30 ATTEMPT=1 @@ -17,8 +17,8 @@ while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do fi if [ "$MODEL_HEALTH_STATUS" = "unhealthy" ]; then - echo "=== nilai-gemma_4b_gpu is unhealthy, printing logs ===" - docker logs nilai-gemma_4b_gpu --tail 50 + echo "=== nilai-qwen2vl_2b_gpu is unhealthy, printing logs ===" + docker logs nilai-qwen2vl_2b_gpu --tail 50 fi if [ "$NUC_API_HEALTH_STATUS" = "unhealthy" ]; then @@ -28,7 +28,7 @@ while [ $ATTEMPT -le $MAX_ATTEMPTS ]; do sleep 30 API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-api 2>/dev/null) - MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-gemma_4b_gpu 2>/dev/null) + MODEL_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-qwen2vl_2b_gpu 2>/dev/null) NUC_API_HEALTH_STATUS=$(docker inspect --format='{{.State.Health.Status}}' nilai-nuc-api 2>/dev/null) if [ "$API_HEALTH_STATUS" = "healthy" ] && [ "$MODEL_HEALTH_STATUS" = "healthy" ] && [ "$NUC_API_HEALTH_STATUS" = "healthy" ]; then break @@ -47,9 +47,9 @@ fi echo "MODEL_HEALTH_STATUS: $MODEL_HEALTH_STATUS" if [ "$MODEL_HEALTH_STATUS" != "healthy" ]; then - echo "Error: nilai-gemma_4b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" - echo "=== Final logs for nilai-gemma_4b_gpu ===" - docker logs nilai-gemma_4b_gpu --tail 100 + echo "Error: nilai-qwen2vl_2b_gpu failed to become healthy after $MAX_ATTEMPTS attempts" + echo "=== Final logs for nilai-qwen2vl_2b_gpu ===" + docker logs nilai-qwen2vl_2b_gpu --tail 100 exit 1 fi diff --git a/tests/e2e/config.py b/tests/e2e/config.py index 7caae4a4..f3cefe89 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -32,7 +32,7 @@ def api_key_getter(): "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", ], - "ci": ["google/gemma-3-4b-it"], + "ci": ["Qwen/Qwen2-VL-2B-Instruct-AWQ"], } if ENVIRONMENT not in models: diff --git a/tests/e2e/test_openai.py b/tests/e2e/test_openai.py index e50c5d9e..1314c721 100644 --- a/tests/e2e/test_openai.py +++ b/tests/e2e/test_openai.py @@ -868,14 +868,14 @@ def make_request(): def test_multimodal_single_request(client): - """Test multimodal chat completion with a single request using gemma-3-4b-it model""" - if "google/gemma-3-4b-it" not in test_models: - pytest.skip("Multimodal test only runs for gemma-3-4b-it model") + """Test multimodal chat completion with a single request using Qwen/Qwen2-VL-2B-Instruct-AWQ model""" + if "Qwen/Qwen2-VL-2B-Instruct-AWQ" not in test_models: + pytest.skip("Multimodal test only runs for Qwen/Qwen2-VL-2B-Instruct-AWQ model") try: # Create a simple base64 encoded image (1x1 pixel red PNG) response = client.chat.completions.create( - model="google/gemma-3-4b-it", + model="Qwen/Qwen2-VL-2B-Instruct-AWQ", messages=[ { "role": "user", @@ -898,8 +898,8 @@ def test_multimodal_single_request(client): assert isinstance(response, ChatCompletion), ( "Response should be a ChatCompletion object" ) - assert response.model == "google/gemma-3-4b-it", ( - "Response model should be google/gemma-3-4b-it" + assert response.model == "Qwen/Qwen2-VL-2B-Instruct-AWQ", ( + "Response model should be Qwen/Qwen2-VL-2B-Instruct-AWQ" ) assert len(response.choices) > 0, "Response should contain at least one choice" @@ -932,9 +932,9 @@ def test_multimodal_single_request(client): def test_multimodal_consecutive_requests(client): - """Test two consecutive multimodal chat completions using gemma-3-4b-it model""" - if "google/gemma-3-4b-it" not in test_models: - pytest.skip("Multimodal test only runs for gemma-3-4b-it model") + """Test two consecutive multimodal chat completions using Qwen/Qwen2-VL-2B-Instruct-AWQ model""" + if "Qwen/Qwen2-VL-2B-Instruct-AWQ" not in test_models: + pytest.skip("Multimodal test only runs for Qwen/Qwen2-VL-2B-Instruct-AWQ model") try: # Create a simple base64 encoded image (1x1 pixel red PNG) @@ -1057,8 +1057,8 @@ def test_multimodal_consecutive_requests(client): def test_multimodal_with_web_search_error(client): """Test that multimodal + web search raises an error""" - if "google/gemma-3-4b-it" not in test_models: - pytest.skip("Multimodal test only runs for gemma-3-4b-it model") + if "Qwen/Qwen2-VL-2B-Instruct-AWQ" not in test_models: + pytest.skip("Multimodal test only runs for Qwen/Qwen2-VL-2B-Instruct-AWQ model") # Create a simple base64 encoded image (1x1 pixel red PNG) diff --git a/tests/unit/nilai_api/routers/test_private.py b/tests/unit/nilai_api/routers/test_private.py index f7e6ea86..e6e20cdc 100644 --- a/tests/unit/nilai_api/routers/test_private.py +++ b/tests/unit/nilai_api/routers/test_private.py @@ -238,7 +238,7 @@ def test_chat_completion_image_web_search_error( response = client.post( "/v1/chat/completions", json={ - "model": "google/gemma-3-4b-it", + "model": "Qwen/Qwen2-VL-2B-Instruct-AWQ", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { @@ -287,7 +287,7 @@ def test_chat_completion_with_image( response = client.post( "/v1/chat/completions", json={ - "model": "google/gemma-3-4b-it", + "model": "Qwen/Qwen2-VL-2B-Instruct-AWQ", "messages": [ {"role": "system", "content": "You are a helpful assistant."}, { From b0f36c63db86d96c08844668bd665d95b1a0163e Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 10:47:23 +0200 Subject: [PATCH 65/75] fix: update qwen config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index bb2f71b4..24f05d7a 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -1,7 +1,7 @@ version: "3.8" services: - qwen2vl_2b_gpu: + c: image: nillion/nilai-vllm:latest container_name: nilai-qwen2vl_2b_gpu deploy: @@ -25,8 +25,8 @@ services: --model-impl vllm --tensor-parallel-size 1 --trust-remote-code - --max-model-len 768 - --max-num-batched-tokens 768 + --max-model-len 2450 + --max-num-batched-tokens 2450 --max-num-seqs 2 --gpu-memory-utilization 0.80 --swap-space 8 From 7c2b14047d94da81e8b44d1cabb1bd8ac7b53765 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 11:04:48 +0200 Subject: [PATCH 66/75] fix: qwen config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 24f05d7a..37bbebdc 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -25,13 +25,15 @@ services: --model-impl vllm --tensor-parallel-size 1 --trust-remote-code - --max-model-len 2450 - --max-num-batched-tokens 2450 - --max-num-seqs 2 + --quantization awq + --max-model-len 1024 + --max-num-batched-tokens 1024 + --max-num-seqs 1 --gpu-memory-utilization 0.80 --swap-space 8 --uvicorn-log-level warning - --quantization awq + --limit-mm-per-prompt image=1,video=0 + --image-input-shape 896,896 environment: - SVC_HOST=qwen2vl_2b_gpu - SVC_PORT=8000 From 4375c032e554bf9978088626f852ba6332e1b753 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 11:21:32 +0200 Subject: [PATCH 67/75] fix: update qwen config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 37bbebdc..569ac56a 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -32,7 +32,7 @@ services: --gpu-memory-utilization 0.80 --swap-space 8 --uvicorn-log-level warning - --limit-mm-per-prompt image=1,video=0 + --limit-mm-per-prompt '{"image": 1, "video": 0}' --image-input-shape 896,896 environment: - SVC_HOST=qwen2vl_2b_gpu From 471c7cb17c3382a2d0f8986f92756aaf2dd4abb0 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 11:43:36 +0200 Subject: [PATCH 68/75] fix: config as list --- .../docker-compose.gemma-4b-gpu.ci.yml | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 569ac56a..c0478b58 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -20,20 +20,22 @@ services: depends_on: etcd: condition: service_healthy - command: > - --model Qwen/Qwen2-VL-2B-Instruct-AWQ - --model-impl vllm - --tensor-parallel-size 1 - --trust-remote-code - --quantization awq - --max-model-len 1024 - --max-num-batched-tokens 1024 - --max-num-seqs 1 - --gpu-memory-utilization 0.80 - --swap-space 8 - --uvicorn-log-level warning - --limit-mm-per-prompt '{"image": 1, "video": 0}' - --image-input-shape 896,896 + command: + [ + "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "--model-impl", "vllm", + "--tensor-parallel-size", "1", + "--trust-remote-code", + "--quantization", "awq", + "--max-model-len", "1024", + "--max-num-batched-tokens", "1024", + "--max-num-seqs", "1", + "--gpu-memory-utilization", "0.80", + "--swap-space", "8", + "--uvicorn-log-level", "warning", + "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", + "--image-input-shape", "896,896" + ] environment: - SVC_HOST=qwen2vl_2b_gpu - SVC_PORT=8000 From a842da816e826887bf177b09def4db4ff335c7f3 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 12:00:41 +0200 Subject: [PATCH 69/75] fix: qwen config --- .../docker-compose.gemma-4b-gpu.ci.yml | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index c0478b58..eb6f9039 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -20,22 +20,19 @@ services: depends_on: etcd: condition: service_healthy - command: - [ - "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", - "--model-impl", "vllm", - "--tensor-parallel-size", "1", - "--trust-remote-code", - "--quantization", "awq", - "--max-model-len", "1024", - "--max-num-batched-tokens", "1024", - "--max-num-seqs", "1", - "--gpu-memory-utilization", "0.80", - "--swap-space", "8", - "--uvicorn-log-level", "warning", - "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", - "--image-input-shape", "896,896" - ] + command: > + --model Qwen/Qwen2-VL-2B-Instruct-AWQ + --model-impl vllm + --tensor-parallel-size 1 + --trust-remote-code + --quantization awq + --max-model-len 1024 + --max-num-batched-tokens 1024 + --max-num-seqs 1 + --gpu-memory-utilization 0.80 + --swap-space 8 + --uvicorn-log-level warning + --limit-mm-per-prompt '{"image": 1, "video": 0}' environment: - SVC_HOST=qwen2vl_2b_gpu - SVC_PORT=8000 From 9115580131ab3a88e90a00c7b0ae31d87940bce2 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 12:16:33 +0200 Subject: [PATCH 70/75] fix: avoid parsing error --- .../docker-compose.gemma-4b-gpu.ci.yml | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index eb6f9039..b6969eed 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -20,19 +20,21 @@ services: depends_on: etcd: condition: service_healthy - command: > - --model Qwen/Qwen2-VL-2B-Instruct-AWQ - --model-impl vllm - --tensor-parallel-size 1 - --trust-remote-code - --quantization awq - --max-model-len 1024 - --max-num-batched-tokens 1024 - --max-num-seqs 1 - --gpu-memory-utilization 0.80 - --swap-space 8 - --uvicorn-log-level warning - --limit-mm-per-prompt '{"image": 1, "video": 0}' + command: + [ + "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "--model-impl", "vllm", + "--tensor-parallel-size", "1", + "--trust-remote-code", + "--quantization", "awq", + "--max-model-len", "1024", + "--max-num-batched-tokens", "1024", + "--max-num-seqs", "1", + "--gpu-memory-utilization", "0.80", + "--swap-space", "8", + "--uvicorn-log-level", "warning", + "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}" + ] environment: - SVC_HOST=qwen2vl_2b_gpu - SVC_PORT=8000 From 30fb2c951bd09f7281fa952caea42af9205e10d1 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 12:30:19 +0200 Subject: [PATCH 71/75] fix: qwen config format --- .../docker-compose.gemma-4b-gpu.ci.yml | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index b6969eed..6c069b49 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -21,20 +21,20 @@ services: etcd: condition: service_healthy command: - [ - "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", - "--model-impl", "vllm", - "--tensor-parallel-size", "1", - "--trust-remote-code", - "--quantization", "awq", - "--max-model-len", "1024", - "--max-num-batched-tokens", "1024", - "--max-num-seqs", "1", - "--gpu-memory-utilization", "0.80", - "--swap-space", "8", - "--uvicorn-log-level", "warning", - "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}" - ] + [ + "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "--model-impl", "vllm", + "--tensor-parallel-size", "1", + "--trust-remote-code", + "--quantization", "awq", + "--max-model-len", "1024", + "--max-num-batched-tokens", "1024", + "--max-num-seqs", "1", + "--gpu-memory-utilization", "0.80", + "--swap-space", "8", + "--uvicorn-log-level", "warning", + "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}" + ] environment: - SVC_HOST=qwen2vl_2b_gpu - SVC_PORT=8000 From aa43f24c621f82e44c7429989764120cca85a04b Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 12:46:57 +0200 Subject: [PATCH 72/75] fix: update config --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 6c069b49..39d7f41d 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -27,10 +27,10 @@ services: "--tensor-parallel-size", "1", "--trust-remote-code", "--quantization", "awq", - "--max-model-len", "1024", - "--max-num-batched-tokens", "1024", + "--max-model-len", "1344", + "--max-num-batched-tokens", "1344", "--max-num-seqs", "1", - "--gpu-memory-utilization", "0.80", + "--gpu-memory-utilization", "0.75", "--swap-space", "8", "--uvicorn-log-level", "warning", "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}" From 8f6b06ecc186ca6378ccfb2eeb23e253dbb283de Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 13:07:22 +0200 Subject: [PATCH 73/75] fix: update config --- .../docker-compose.gemma-4b-gpu.ci.yml | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 39d7f41d..c42779e2 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -27,24 +27,29 @@ services: "--tensor-parallel-size", "1", "--trust-remote-code", "--quantization", "awq", - "--max-model-len", "1344", - "--max-num-batched-tokens", "1344", + + "--max-model-len", "1280", + "--max-num-batched-tokens", "1280", "--max-num-seqs", "1", + "--gpu-memory-utilization", "0.75", "--swap-space", "8", "--uvicorn-log-level", "warning", - "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}" + + "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", + "--skip-mm-profiling" ] + environment: - - SVC_HOST=qwen2vl_2b_gpu - - SVC_PORT=8000 - - ETCD_HOST=etcd - - ETCD_PORT=2379 - - TOOL_SUPPORT=true - - MULTIMODAL_SUPPORT=true - - CUDA_LAUNCH_BLOCKING=1 - - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 - - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + SVC_HOST: qwen2vl_2b_gpu + SVC_PORT: "8000" + ETCD_HOST: etcd + ETCD_PORT: "2379" + TOOL_SUPPORT: "true" + MULTIMODAL_SUPPORT: "true" + CUDA_LAUNCH_BLOCKING: "1" + VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1" + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" volumes: - hugging_face_models:/root/.cache/huggingface healthcheck: From 244d14d7298ce5546f8230afacd6fef6d5678a35 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort Date: Mon, 1 Sep 2025 13:28:05 +0200 Subject: [PATCH 74/75] fix: enfore eager --- docker/compose/docker-compose.gemma-4b-gpu.ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index c42779e2..9b991011 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -37,7 +37,8 @@ services: "--uvicorn-log-level", "warning", "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", - "--skip-mm-profiling" + "--skip-mm-profiling", + "--enforce-eager" ] environment: From fd9ab4393230357e08c3943f510955dca538a172 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Cabrero-Holgueras?= Date: Tue, 2 Sep 2025 09:45:59 +0000 Subject: [PATCH 75/75] fix: api model fixes --- .github/workflows/cicd.yml | 6 +- docker-compose.dev.yml | 4 - .../compose/docker-compose.gemma-27b-gpu.yml | 13 +- .../docker-compose.gemma-4b-gpu.ci.yml | 51 ++---- .../compose/docker-compose.llama-8b-gpu.yml | 2 +- .../compose/docker-compose.qwen-2b-gpu.ci.yml | 64 +++++++ nilai-api/src/nilai_api/handlers/nilrag.py | 21 +-- .../src/nilai_api/handlers/web_search.py | 59 +++--- nilai-api/src/nilai_api/routers/private.py | 19 +- .../src/nilai_api/utils/content_extractor.py | 87 --------- .../nilai-common/src/nilai_common/__init__.py | 4 + .../src/nilai_common/api_model.py | 173 +++++++++++++++++- tests/e2e/config.py | 2 +- tests/unit/__init__.py | 4 +- 14 files changed, 319 insertions(+), 190 deletions(-) create mode 100644 docker/compose/docker-compose.qwen-2b-gpu.ci.yml delete mode 100644 nilai-api/src/nilai_api/utils/content_extractor.py diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 9631f18d..d07935a5 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -137,10 +137,10 @@ jobs: sed -i 's/BRAVE_SEARCH_API=.*/BRAVE_SEARCH_API=${{ secrets.BRAVE_SEARCH_API }}/' .env - name: Compose docker-compose.yml - run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.gemma-4b-gpu.ci.yml -o development-compose.yml + run: python3 ./scripts/docker-composer.py --dev -f docker/compose/docker-compose.llama-1b-cpu.ci.yml -o development-compose.yml - name: GPU stack versions (non-fatal) - shell: bash + shell: bash run: | set +e # never fail this step @@ -322,4 +322,4 @@ jobs: mode: stop github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} label: ${{ needs.start-runner.outputs.label }} - ec2-instances-ids: ${{ needs.start-runner.outputs.ec2-instances-ids }} \ No newline at end of file + ec2-instances-ids: ${{ needs.start-runner.outputs.ec2-instances-ids }} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 41ef9771..25066491 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -2,10 +2,6 @@ services: caddy: env_file: - .env - ports: - - "80:80" - - "443:443" - - "443:443/udp" volumes: - ./caddy/Caddyfile:/etc/caddy/Caddyfile api: diff --git a/docker/compose/docker-compose.gemma-27b-gpu.yml b/docker/compose/docker-compose.gemma-27b-gpu.yml index 95a76721..754b44c3 100644 --- a/docker/compose/docker-compose.gemma-27b-gpu.yml +++ b/docker/compose/docker-compose.gemma-27b-gpu.yml @@ -20,19 +20,18 @@ services: condition: service_healthy command: > --model google/gemma-3-27b-it - --gpu-memory-utilization 0.95 + --gpu-memory-utilization 0.79 --max-model-len 60000 - --max-num-batched-tokens 60000 - --tensor-parallel-size 1 - --enable-auto-tool-choice - --tool-call-parser llama3_json + --max-num-batched-tokens 8192 + --dtype bfloat16 + --kv-cache-dtype fp8 --uvicorn-log-level warning environment: - SVC_HOST=gemma_27b_gpu - SVC_PORT=8000 - ETCD_HOST=etcd - ETCD_PORT=2379 - - TOOL_SUPPORT=true + - TOOL_SUPPORT=false - MULTIMODAL_SUPPORT=true volumes: - hugging_face_models:/root/.cache/huggingface @@ -43,4 +42,4 @@ services: start_period: 60s timeout: 10s volumes: - hugging_face_models: \ No newline at end of file + hugging_face_models: diff --git a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml index 9b991011..29423275 100644 --- a/docker/compose/docker-compose.gemma-4b-gpu.ci.yml +++ b/docker/compose/docker-compose.gemma-4b-gpu.ci.yml @@ -1,9 +1,7 @@ -version: "3.8" - services: - c: + gemma_4b_gpu: image: nillion/nilai-vllm:latest - container_name: nilai-qwen2vl_2b_gpu + container_name: nilai-gemma_4b_gpu deploy: resources: reservations: @@ -11,6 +9,7 @@ services: - driver: nvidia count: 1 capabilities: [gpu] + ulimits: memlock: -1 stack: 67108864 @@ -20,37 +19,22 @@ services: depends_on: etcd: condition: service_healthy - command: - [ - "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", - "--model-impl", "vllm", - "--tensor-parallel-size", "1", - "--trust-remote-code", - "--quantization", "awq", - - "--max-model-len", "1280", - "--max-num-batched-tokens", "1280", - "--max-num-seqs", "1", - - "--gpu-memory-utilization", "0.75", - "--swap-space", "8", - "--uvicorn-log-level", "warning", - - "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", - "--skip-mm-profiling", - "--enforce-eager" - ] + command: > + --model google/gemma-3-4b-it + --max-model-len 30000 + --max-num-batched-tokens 8192 + --uvicorn-log-level warning environment: - SVC_HOST: qwen2vl_2b_gpu - SVC_PORT: "8000" - ETCD_HOST: etcd - ETCD_PORT: "2379" - TOOL_SUPPORT: "true" - MULTIMODAL_SUPPORT: "true" - CUDA_LAUNCH_BLOCKING: "1" - VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1" - PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + - SVC_HOST=gemma_4b_gpu + - SVC_PORT=8000 + - ETCD_HOST=etcd + - ETCD_PORT=2379 + - TOOL_SUPPORT=false + - MULTIMODAL_SUPPORT=true + - CUDA_LAUNCH_BLOCKING=1 + - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 + - PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True volumes: - hugging_face_models:/root/.cache/huggingface healthcheck: @@ -59,6 +43,5 @@ services: retries: 3 start_period: 60s timeout: 10s - volumes: hugging_face_models: diff --git a/docker/compose/docker-compose.llama-8b-gpu.yml b/docker/compose/docker-compose.llama-8b-gpu.yml index 18392928..7ecdba10 100644 --- a/docker/compose/docker-compose.llama-8b-gpu.yml +++ b/docker/compose/docker-compose.llama-8b-gpu.yml @@ -20,7 +20,7 @@ services: condition: service_healthy command: > --model meta-llama/Llama-3.1-8B-Instruct - --gpu-memory-utilization 0.21 + --gpu-memory-utilization 0.20 --max-model-len 10000 --max-num-batched-tokens 10000 --tensor-parallel-size 1 diff --git a/docker/compose/docker-compose.qwen-2b-gpu.ci.yml b/docker/compose/docker-compose.qwen-2b-gpu.ci.yml new file mode 100644 index 00000000..7d040caf --- /dev/null +++ b/docker/compose/docker-compose.qwen-2b-gpu.ci.yml @@ -0,0 +1,64 @@ +version: "3.8" + +services: + c: + image: nillion/nilai-vllm:latest + container_name: qwen2vl_2b_gpu + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + ulimits: + memlock: -1 + stack: 67108864 + env_file: + - .env + restart: unless-stopped + depends_on: + etcd: + condition: service_healthy + command: + [ + "--model", "Qwen/Qwen2-VL-2B-Instruct-AWQ", + "--model-impl", "vllm", + "--tensor-parallel-size", "1", + "--trust-remote-code", + "--quantization", "awq", + + "--max-model-len", "1280", + "--max-num-batched-tokens", "1280", + "--max-num-seqs", "1", + + "--gpu-memory-utilization", "0.75", + "--swap-space", "8", + "--uvicorn-log-level", "warning", + + "--limit-mm-per-prompt", "{\"image\":1,\"video\":0}", + "--skip-mm-profiling", + "--enforce-eager" + ] + + environment: + SVC_HOST: qwen2vl_2b_gpu + SVC_PORT: "8000" + ETCD_HOST: etcd + ETCD_PORT: "2379" + TOOL_SUPPORT: "true" + MULTIMODAL_SUPPORT: "true" + CUDA_LAUNCH_BLOCKING: "1" + VLLM_ALLOW_LONG_MAX_MODEL_LEN: "1" + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + volumes: + - hugging_face_models:/root/.cache/huggingface + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + retries: 3 + start_period: 60s + timeout: 10s + +volumes: + hugging_face_models: diff --git a/nilai-api/src/nilai_api/handlers/nilrag.py b/nilai-api/src/nilai_api/handlers/nilrag.py index b3121071..630d9088 100644 --- a/nilai-api/src/nilai_api/handlers/nilrag.py +++ b/nilai-api/src/nilai_api/handlers/nilrag.py @@ -3,10 +3,9 @@ import nilrag -from nilai_common import ChatRequest +from nilai_common import ChatRequest, MessageAdapter from fastapi import HTTPException, status from sentence_transformers import SentenceTransformer -from nilai_api.utils.content_extractor import extract_text_content logger = logging.getLogger(__name__) @@ -64,11 +63,7 @@ async def handle_nilrag(req: ChatRequest): # Get user query logger.debug("Extracting user query") - query = None - for message in req.messages: - if message.get("role") == "user" and message.get("content") is not None: - query = extract_text_content(message.get("content")) # type: ignore - break + query = req.get_last_user_query() if not query: raise HTTPException(status_code=400, detail="No user query found") @@ -86,9 +81,9 @@ async def handle_nilrag(req: ChatRequest): relevant_context = f"\n\nRelevant Context:\n{formatted_results}" # Step 4: Update system message - for message in req.messages: - if message.get("role") == "system": - content = message.get("content") + for message in req.adapted_messages: + if message.role == "system": + content = message.content if content is None: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -96,13 +91,15 @@ async def handle_nilrag(req: ChatRequest): ) if isinstance(content, str): - message["content"] = content + relevant_context + message.content = content + relevant_context elif isinstance(content, list): content.append({"type": "text", "text": relevant_context}) break else: # If no system message exists, add one - req.messages.insert(0, {"role": "system", "content": relevant_context}) + req.messages.insert( + 0, MessageAdapter.new_message(role="system", content=relevant_context) + ) logger.debug(f"System message updated with relevant context:\n {req.messages}") diff --git a/nilai-api/src/nilai_api/handlers/web_search.py b/nilai-api/src/nilai_api/handlers/web_search.py index ac7f70b1..3dfff7af 100644 --- a/nilai-api/src/nilai_api/handlers/web_search.py +++ b/nilai-api/src/nilai_api/handlers/web_search.py @@ -6,14 +6,15 @@ from fastapi import HTTPException, status from nilai_api.config import WEB_SEARCH_SETTINGS -from nilai_api.utils.content_extractor import get_last_user_query from nilai_common.api_model import ( + ChatRequest, + Message, + MessageAdapter, SearchResult, Source, WebSearchEnhancedMessages, WebSearchContext, ) -from openai.types.chat import ChatCompletionMessageParam logger = logging.getLogger(__name__) @@ -182,7 +183,7 @@ def _get_role_and_content(msg): async def enhance_messages_with_web_search( - messages: List[ChatCompletionMessageParam], query: str + messages: List[Message], query: str ) -> WebSearchEnhancedMessages: """Enhance a list of messages with web search context.Collapse commentComment on line L155jcabrero commented on Aug 26, 2025 jcabreroon Aug 26, 2025MemberDeleted docstring?Write a replyResolve commentCode has comments. Press enter to view. @@ -206,31 +207,37 @@ async def enhance_messages_with_web_search( "Please provide a comprehensive answer based on the search results above." ) - enhanced: List[ChatCompletionMessageParam] = [] + enhanced: List[Message] = [] system_message_added = False for msg in messages: - role, content = _get_role_and_content(msg) - - if role == "system" and not system_message_added: - if isinstance(content, str): - combined_content = content + "\n\n" + web_search_content - elif isinstance(content, list): + adapted_message = MessageAdapter(raw=msg) + + if adapted_message.role == "system" and not system_message_added: + if isinstance(adapted_message.content, str): + combined_content_str = ( + adapted_message.content + "\n\n" + web_search_content + ) + elif isinstance(adapted_message.content, list): # content is likely a list of parts (for multimodal); append a text part - parts = list(content) + parts = list(adapted_message.content) parts.append({"type": "text", "text": "\n\n" + web_search_content}) - combined_content = parts + combined_content_str = parts else: - combined_content = web_search_content - - enhanced.append({"role": "system", "content": combined_content}) # type: ignore + combined_content_str = web_search_content + enhanced.append( + MessageAdapter.new_message(role="system", content=combined_content_str) + ) system_message_added = True else: # Re-append in dict form - enhanced.append({"role": role or "user", "content": content}) # type: ignore + + enhanced.append(adapted_message.to_openai_param()) if not system_message_added: - enhanced.insert(0, {"role": "system", "content": web_search_content}) + enhanced.insert( + 0, MessageAdapter.new_message(role="system", content=web_search_content) + ) return WebSearchEnhancedMessages( messages=enhanced, @@ -292,7 +299,7 @@ async def generate_search_query_from_llm( async def handle_web_search( - req_messages: List[ChatCompletionMessageParam], model_name: str, client + req_messages: ChatRequest, model_name: str, client ) -> WebSearchEnhancedMessages: """Handle web search enhancement for a conversation. @@ -310,24 +317,28 @@ async def handle_web_search( """ logger.info("Handle web search start") logger.debug( - "Handle web search messages_in=%d model=%s", len(req_messages), model_name + "Handle web search messages_in=%d model=%s", + len(req_messages.messages), + model_name, ) - user_query = get_last_user_query(req_messages) + user_query = req_messages.get_last_user_query() if not user_query: logger.info("No user query found") - return WebSearchEnhancedMessages(messages=req_messages, sources=[]) + return WebSearchEnhancedMessages(messages=req_messages.messages, sources=[]) try: concise_query = await generate_search_query_from_llm( user_query, model_name, client ) logger.info("Enhancing messages with web search context") - return await enhance_messages_with_web_search(req_messages, concise_query) + return await enhance_messages_with_web_search( + req_messages.messages, concise_query + ) except HTTPException: logger.exception("Web search provider error") - return WebSearchEnhancedMessages(messages=req_messages, sources=[]) + return WebSearchEnhancedMessages(messages=req_messages.messages, sources=[]) except Exception: logger.exception("Unexpected error during web search handling") - return WebSearchEnhancedMessages(messages=req_messages, sources=[]) + return WebSearchEnhancedMessages(messages=req_messages.messages, sources=[]) diff --git a/nilai-api/src/nilai_api/routers/private.py b/nilai-api/src/nilai_api/routers/private.py index 6a2fc740..376e6570 100644 --- a/nilai-api/src/nilai_api/routers/private.py +++ b/nilai-api/src/nilai_api/routers/private.py @@ -8,7 +8,6 @@ from nilai_api.attestation import get_attestation_report from nilai_api.handlers.nilrag import handle_nilrag from nilai_api.handlers.web_search import handle_web_search -from nilai_api.utils.content_extractor import has_multimodal_content from fastapi import APIRouter, Body, Depends, HTTPException, status, Request from fastapi.responses import StreamingResponse @@ -25,6 +24,7 @@ AttestationReport, ChatRequest, ModelMetadata, + MessageAdapter, SignedChatCompletion, Nonce, Source, @@ -136,8 +136,10 @@ async def chat_completion( ChatRequest( model="meta-llama/Llama-3.2-1B-Instruct", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, + MessageAdapter.new_message( + role="system", content="You are a helpful assistant." + ), + MessageAdapter.new_message(role="user", content="What is your name?"), ], ) ), @@ -196,12 +198,17 @@ async def chat_completion( model="meta-llama/Llama-3.2-1B-Instruct", messages=[ {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": "What's the latest news about AI?"} + {"role": "user", "content": "What is your name?"} ], ) response = await chat_completion(request, user) """ + if len(req.messages) == 0: + raise HTTPException( + status_code=400, + detail="Request contained 0 messages", + ) model_name = req.model request_id = str(uuid.uuid4()) t_start = time.monotonic() @@ -219,7 +226,7 @@ async def chat_completion( detail="Model does not support tool usage, remove tools from request", ) - has_multimodal = has_multimodal_content(req.messages) + has_multimodal = req.has_multimodal_content() logger.info(f"[chat] has_multimodal: {has_multimodal}") if has_multimodal and (not endpoint.metadata.multimodal_support or req.web_search): raise HTTPException( @@ -249,7 +256,7 @@ async def chat_completion( if req.web_search: logger.info(f"[chat] web_search start request_id={request_id}") t_ws = time.monotonic() - web_search_result = await handle_web_search(messages, model_name, client) + web_search_result = await handle_web_search(req, model_name, client) messages = web_search_result.messages sources = web_search_result.sources logger.info( diff --git a/nilai-api/src/nilai_api/utils/content_extractor.py b/nilai-api/src/nilai_api/utils/content_extractor.py deleted file mode 100644 index 2d8b2092..00000000 --- a/nilai-api/src/nilai_api/utils/content_extractor.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Union, List, Optional, Iterable -from openai.types.chat.chat_completion_content_part_text_param import ( - ChatCompletionContentPartTextParam, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ChatCompletionContentPartImageParam, -) -from openai.types.chat import ChatCompletionMessageParam - - -def _iter_parts(content): - if content is None or isinstance(content, str): - return () - if isinstance(content, dict): - return (content,) - try: - iter(content) - return content - except TypeError: - return () - - -def _is_multimodal_part(part) -> bool: - if isinstance(part, dict): - if "image_url" in part or "input_audio" in part or "file" in part: - return True - t = part.get("type") - return t in {"image_url", "input_audio", "file"} - if ( - hasattr(part, "image_url") - or hasattr(part, "input_audio") - or hasattr(part, "file") - ): - return True - t = getattr(part, "type", None) - return t in {"image_url", "input_audio", "file"} - - -def extract_text_content( - content: Union[ - str, - Iterable[ - Union[ - ChatCompletionContentPartTextParam, - ChatCompletionContentPartImageParam, - dict, - ] - ], - None, - ], -) -> str: - if isinstance(content, str): - return content - for part in _iter_parts(content): - if isinstance(part, dict): - if part.get("type") == "text": # type: ignore - txt = part.get("text") # type: ignore - if isinstance(txt, str): - return txt - else: - if getattr(part, "type", None) == "text": - txt = getattr(part, "text", None) - if isinstance(txt, str): - return txt - return "" - - -def has_multimodal_content(messages: List[ChatCompletionMessageParam]) -> bool: - for m in messages: - for part in _iter_parts(m.get("content")): - if _is_multimodal_part(part): - return True - return False - - -def get_last_user_query(messages: List[ChatCompletionMessageParam]) -> Optional[str]: - for msg in reversed(messages): - if msg.get("role") == "user": - content = msg.get("content") - if content is not None: - try: - text = extract_text_content(content) # type: ignore - except Exception: - text = None - if text: - return text.strip() - return None diff --git a/packages/nilai-common/src/nilai_common/__init__.py b/packages/nilai-common/src/nilai_common/__init__.py index 77a4e666..e29eef27 100644 --- a/packages/nilai-common/src/nilai_common/__init__.py +++ b/packages/nilai-common/src/nilai_common/__init__.py @@ -13,12 +13,16 @@ Source, WebSearchEnhancedMessages, WebSearchContext, + Message, + MessageAdapter, ) from nilai_common.config import SETTINGS from nilai_common.discovery import ModelServiceDiscovery from openai.types.completion_usage import CompletionUsage as Usage __all__ = [ + "Message", + "MessageAdapter", "ChatRequest", "SignedChatCompletion", "Choice", diff --git a/packages/nilai-common/src/nilai_common/api_model.py b/packages/nilai-common/src/nilai_common/api_model.py index ca6307af..3bb9485f 100644 --- a/packages/nilai-common/src/nilai_common/api_model.py +++ b/packages/nilai-common/src/nilai_common/api_model.py @@ -1,13 +1,41 @@ -import uuid +from __future__ import annotations -from typing import Annotated, Iterable, List, Optional +import uuid +from typing import ( + Annotated, + Iterable, + List, + Optional, + Any, + cast, + TypeAlias, + Literal, + Union, +) -from openai.types.chat import ChatCompletion, ChatCompletionMessageParam -from openai.types.chat import ChatCompletionToolParam +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessageParam, + ChatCompletionToolParam, + ChatCompletionMessage, +) +from openai.types.chat.chat_completion_content_part_text_param import ( + ChatCompletionContentPartTextParam, +) +from openai.types.chat.chat_completion_content_part_image_param import ( + ChatCompletionContentPartImageParam, +) from openai.types.chat.chat_completion import Choice as OpenaAIChoice -from pydantic import BaseModel, Field, SkipValidation +from pydantic import BaseModel, Field + + +# ---------- Aliases from the OpenAI SDK ---------- +ImageContent: TypeAlias = ChatCompletionContentPartImageParam +TextContent: TypeAlias = ChatCompletionContentPartTextParam +Message: TypeAlias = ChatCompletionMessageParam # SDK union of message shapes +# ---------- Models you already had ---------- class Choice(OpenaAIChoice): pass @@ -23,8 +51,105 @@ class SearchResult(BaseModel): url: str +# ---------- Helpers ---------- +def _extract_text_from_content(content: Any) -> Optional[str]: + """ + - If content is a str -> return it (stripped) if non-empty. + - If content is a list of content parts -> concatenate 'text' parts. + - Else -> None. + """ + if isinstance(content, str): + s = content.strip() + return s or None + if isinstance(content, list): + parts: List[str] = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + t = part.get("text") + if isinstance(t, str) and t.strip(): + parts.append(t.strip()) + if parts: + return "\n".join(parts) + return None + + +# ---------- Adapter over the raw SDK message ---------- +class MessageAdapter(BaseModel): + """Thin wrapper around an OpenAI ChatCompletionMessageParam with convenience methods.""" + + raw: Message + + @property + def role(self) -> str: + return cast(str, self.raw.get("role")) + + @role.setter + def role( + self, + value: Literal["developer", "user", "system", "assistant", "tool", "function"], + ) -> None: + if not isinstance(value, str): + raise TypeError("role must be a string") + # Update the underlying SDK message dict + # Cast to Any to bypass TypedDict restrictions + cast(Any, self.raw)["role"] = value + + @property + def content(self) -> Any: + return self.raw.get("content") + + @content.setter + def content(self, value: Any) -> None: + # Update the underlying SDK message dict + # Cast to Any to bypass TypedDict restrictions + cast(Any, self.raw)["content"] = value + + @staticmethod + def new_message( + role: Literal["developer", "user", "system", "assistant", "tool", "function"], + content: Union[str, List[Any]], + ) -> Message: + message: Message = cast(Message, {"role": role, "content": content}) + return message + + @staticmethod + def new_completion_message(content: str) -> ChatCompletionMessage: + message: ChatCompletionMessage = cast( + ChatCompletionMessage, {"role": "assistant", "content": content} + ) + return message + + def is_text_part(self) -> bool: + return _extract_text_from_content(self.content) is not None + + def is_multimodal_part(self) -> bool: + c = self.content + if isinstance(c, str): + return False + + for part in c: + if isinstance(part, dict) and part.get("type") in ( + "image_url", + "input_image", + ): + return True + return False + + def extract_text(self) -> Optional[str]: + return _extract_text_from_content(self.content) + + def to_openai_param(self) -> Message: + # Return the original dict for API calls. + return self.raw + + +def adapt_messages(msgs: List[Message]) -> List[MessageAdapter]: + return [MessageAdapter(raw=m) for m in msgs] + + +# ---------- Your additional containers ---------- class WebSearchEnhancedMessages(BaseModel): - messages: List[ChatCompletionMessageParam] + messages: List[Message] sources: List[Source] @@ -35,11 +160,10 @@ class WebSearchContext(BaseModel): sources: List[Source] +# ---------- Request/response models ---------- class ChatRequest(BaseModel): model: str - messages: SkipValidation[List[ChatCompletionMessageParam]] = Field( - ..., min_length=1 - ) + messages: List[Message] = Field(..., min_length=1) temperature: Optional[float] = Field(default=0.2, ge=0.0, le=5.0) top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0) max_tokens: Optional[int] = Field(default=2048, ge=1, le=100000) @@ -51,6 +175,36 @@ class ChatRequest(BaseModel): description="Enable web search to enhance context with current information", ) + def model_post_init(self, __context) -> None: + # Process messages after model initialization + for i, msg in enumerate(self.messages): + content = msg.get("content") + if ( + content is not None + and hasattr(content, "__iter__") + and hasattr(content, "__next__") + ): + # Convert iterator to list in place + cast(Any, msg)["content"] = list(content) + + @property + def adapted_messages(self) -> List[MessageAdapter]: + return adapt_messages(self.messages) + + def get_last_user_query(self) -> Optional[str]: + """ + Returns the latest non-empty user text (plain or from content parts), + or None if not found. + """ + for m in reversed(self.adapted_messages): + if m.role == "user" and m.is_text_part(): + return m.extract_text() + return None + + def has_multimodal_content(self) -> bool: + """True if any message contains an image content part.""" + return any([m.is_multimodal_part() for m in self.adapted_messages]) + class SignedChatCompletion(ChatCompletion): signature: str @@ -82,6 +236,7 @@ class HealthCheckResponse(BaseModel): uptime: str +# ---------- Attestation ---------- Nonce = Annotated[ str, Field( diff --git a/tests/e2e/config.py b/tests/e2e/config.py index f3cefe89..7caae4a4 100644 --- a/tests/e2e/config.py +++ b/tests/e2e/config.py @@ -32,7 +32,7 @@ def api_key_getter(): "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", ], - "ci": ["Qwen/Qwen2-VL-2B-Instruct-AWQ"], + "ci": ["google/gemma-3-4b-it"], } if ENVIRONMENT not in models: diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index ec5d5b07..4b43a545 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,5 +1,4 @@ from openai.types.chat.chat_completion import ChoiceLogprobs -from openai.types.chat import ChatCompletionMessage from nilai_common import ( SignedChatCompletion, @@ -7,6 +6,7 @@ ModelMetadata, Usage, Choice, + MessageAdapter, ) model_metadata: ModelMetadata = ModelMetadata( @@ -33,7 +33,7 @@ choices=[ Choice( index=0, - message=ChatCompletionMessage(role="assistant", content="test-content"), + message=MessageAdapter.new_completion_message(content="test-content"), finish_reason="stop", logprobs=ChoiceLogprobs(), )