diff --git a/.circleci/config.yml b/.circleci/config.yml index 5aeebc1074c..5e441ade02d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -839,6 +839,52 @@ jobs: paths: - guardrails_coverage.xml - guardrails_coverage + + google_generate_content_endpoint_testing: + docker: + - image: cimg/python:3.11 + auth: + username: ${DOCKERHUB_USERNAME} + password: ${DOCKERHUB_PASSWORD} + working_directory: ~/project + + steps: + - checkout + - setup_google_dns + - run: + name: Install Dependencies + command: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + pip install "pytest==7.3.1" + pip install "pytest-retry==1.6.3" + pip install "pytest-cov==5.0.0" + pip install "pytest-asyncio==0.21.1" + pip install "respx==0.22.0" + pip install "pydantic==2.10.2" + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/unified_google_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m + - run: + name: Rename the coverage files + command: | + mv coverage.xml google_generate_content_endpoint_coverage.xml + mv .coverage google_generate_content_endpoint_coverage + + # Store test results + - store_test_results: + path: test-results + - persist_to_workspace: + root: . + paths: + - google_generate_content_endpoint_coverage.xml + - google_generate_content_endpoint_coverage + llm_responses_api_testing: docker: - image: cimg/python:3.11 @@ -3001,6 +3047,12 @@ workflows: only: - main - /litellm_.*/ + - google_generate_content_endpoint_testing: + filters: + branches: + only: + - main + - /litellm_.*/ - llm_responses_api_testing: filters: branches: @@ -3047,6 +3099,7 @@ workflows: requires: - llm_translation_testing - mcp_testing + - google_generate_content_endpoint_testing - guardrails_testing - llm_responses_api_testing - litellm_mapped_tests @@ -3106,6 +3159,7 @@ workflows: - test_bad_database_url - llm_translation_testing - mcp_testing + - google_generate_content_endpoint_testing - llm_responses_api_testing - litellm_mapped_tests - batches_testing diff --git a/litellm/google_genai/main.py b/litellm/google_genai/main.py index c34b1663e6f..87970885355 100644 --- a/litellm/google_genai/main.py +++ b/litellm/google_genai/main.py @@ -24,11 +24,14 @@ GenerateContentConfigDict, GenerateContentContentListUnionDict, GenerateContentResponse, + ToolConfigDict, ) else: GenerateContentConfigDict = Any GenerateContentContentListUnionDict = Any GenerateContentResponse = Any + ToolConfigDict = Any + ####### ENVIRONMENT VARIABLES ################### # Initialize any necessary instances or variables here @@ -83,6 +86,7 @@ def setup_generate_content_call( config: Optional[GenerateContentConfigDict] = None, custom_llm_provider: Optional[str] = None, stream: bool = False, + tools: Optional[ToolConfigDict] = None, **kwargs, ) -> GenerateContentSetupResult: """ @@ -166,6 +170,7 @@ def setup_generate_content_call( generate_content_provider_config.transform_generate_content_request( model=model, contents=contents, + tools=tools, generate_content_config_dict=generate_content_config_dict, ) ) @@ -200,6 +205,7 @@ async def agenerate_content( model: str, contents: GenerateContentContentListUnionDict, config: Optional[GenerateContentConfigDict] = None, + tools: Optional[ToolConfigDict] = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Dict[str, Any]] = None, @@ -235,6 +241,7 @@ async def agenerate_content( extra_body=extra_body, timeout=timeout, custom_llm_provider=custom_llm_provider, + tools=tools, **kwargs, ) @@ -263,6 +270,7 @@ def generate_content( model: str, contents: GenerateContentContentListUnionDict, config: Optional[GenerateContentConfigDict] = None, + tools: Optional[ToolConfigDict] = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Dict[str, Any]] = None, @@ -296,6 +304,7 @@ def generate_content( config=config, custom_llm_provider=custom_llm_provider, stream=False, + tools=tools, **kwargs, ) @@ -316,6 +325,7 @@ def generate_content( response = base_llm_http_handler.generate_content_handler( model=setup_result.model, contents=contents, + tools=tools, generate_content_provider_config=setup_result.generate_content_provider_config, generate_content_config_dict=setup_result.generate_content_config_dict, custom_llm_provider=setup_result.custom_llm_provider, @@ -346,6 +356,7 @@ async def agenerate_content_stream( model: str, contents: GenerateContentContentListUnionDict, config: Optional[GenerateContentConfigDict] = None, + tools: Optional[ToolConfigDict] = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Dict[str, Any]] = None, @@ -377,6 +388,7 @@ async def agenerate_content_stream( "config": config, "custom_llm_provider": custom_llm_provider, "stream": True, + "tools": tools, **kwargs, } ) @@ -402,6 +414,7 @@ async def agenerate_content_stream( contents=contents, generate_content_provider_config=setup_result.generate_content_provider_config, generate_content_config_dict=setup_result.generate_content_config_dict, + tools=tools, custom_llm_provider=setup_result.custom_llm_provider, litellm_params=setup_result.litellm_params, logging_obj=setup_result.litellm_logging_obj, @@ -429,6 +442,7 @@ def generate_content_stream( model: str, contents: GenerateContentContentListUnionDict, config: Optional[GenerateContentConfigDict] = None, + tools: Optional[ToolConfigDict] = None, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Dict[str, Any]] = None, @@ -454,6 +468,7 @@ def generate_content_stream( config=config, custom_llm_provider=custom_llm_provider, stream=True, + tools=tools, **kwargs, ) @@ -476,6 +491,7 @@ def generate_content_stream( contents=contents, generate_content_provider_config=setup_result.generate_content_provider_config, generate_content_config_dict=setup_result.generate_content_config_dict, + tools=tools, custom_llm_provider=setup_result.custom_llm_provider, litellm_params=setup_result.litellm_params, logging_obj=setup_result.litellm_logging_obj, diff --git a/litellm/llms/base_llm/google_genai/transformation.py b/litellm/llms/base_llm/google_genai/transformation.py index 9706b226c47..6dbccaada9a 100644 --- a/litellm/llms/base_llm/google_genai/transformation.py +++ b/litellm/llms/base_llm/google_genai/transformation.py @@ -10,12 +10,14 @@ GenerateContentConfigDict, GenerateContentContentListUnionDict, GenerateContentResponse, + ToolConfigDict, ) else: GenerateContentConfigDict = Any GenerateContentContentListUnionDict = Any GenerateContentResponse = Any LiteLLMLoggingObj = Any + ToolConfigDict = Any from litellm.types.router import GenericLiteLLMParams @@ -145,6 +147,7 @@ def transform_generate_content_request( self, model: str, contents: GenerateContentContentListUnionDict, + tools: Optional[ToolConfigDict], generate_content_config_dict: Dict, ) -> dict: """ @@ -153,6 +156,7 @@ def transform_generate_content_request( Args: model: The model name contents: Input contents + tools: Tools generate_content_request_params: Request parameters litellm_params: LiteLLM parameters headers: Request headers diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 46fd866be2b..3e7dff91820 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -3196,6 +3196,7 @@ def generate_content_handler( contents: Any, generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig, generate_content_config_dict: Dict, + tools: Any, custom_llm_provider: str, litellm_params: GenericLiteLLMParams, logging_obj: LiteLLMLoggingObj, @@ -3221,6 +3222,7 @@ def generate_content_handler( contents=contents, generate_content_provider_config=generate_content_provider_config, generate_content_config_dict=generate_content_config_dict, + tools=tools, custom_llm_provider=custom_llm_provider, litellm_params=litellm_params, logging_obj=logging_obj, @@ -3256,6 +3258,7 @@ def generate_content_handler( data = generate_content_provider_config.transform_generate_content_request( model=model, contents=contents, + tools=tools, generate_content_config_dict=generate_content_config_dict, ) @@ -3317,6 +3320,7 @@ async def async_generate_content_handler( contents: Any, generate_content_provider_config: BaseGoogleGenAIGenerateContentConfig, generate_content_config_dict: Dict, + tools: Any, custom_llm_provider: str, litellm_params: GenericLiteLLMParams, logging_obj: LiteLLMLoggingObj, @@ -3360,6 +3364,7 @@ async def async_generate_content_handler( data = generate_content_provider_config.transform_generate_content_request( model=model, contents=contents, + tools=tools, generate_content_config_dict=generate_content_config_dict, ) diff --git a/litellm/llms/gemini/google_genai/transformation.py b/litellm/llms/gemini/google_genai/transformation.py index 5a07dd13a97..28142f72739 100644 --- a/litellm/llms/gemini/google_genai/transformation.py +++ b/litellm/llms/gemini/google_genai/transformation.py @@ -18,11 +18,14 @@ GenerateContentConfigDict, GenerateContentContentListUnionDict, GenerateContentResponse, + ToolConfigDict, ) else: GenerateContentConfigDict = Any GenerateContentContentListUnionDict = Any GenerateContentResponse = Any + ToolConfigDict = Any + from ..common_utils import get_api_key_from_env class GoogleGenAIConfig(BaseGoogleGenAIGenerateContentConfig, VertexLLM): @@ -258,6 +261,7 @@ def transform_generate_content_request( self, model: str, contents: GenerateContentContentListUnionDict, + tools: Optional[ToolConfigDict], generate_content_config_dict: Dict, ) -> dict: from litellm.types.google_genai.main import ( @@ -267,6 +271,7 @@ def transform_generate_content_request( typed_generate_content_request = GenerateContentRequestDict( model=model, contents=contents, + tools=tools, generationConfig=GenerateContentConfigDict(**generate_content_config_dict), ) diff --git a/litellm/llms/vertex_ai/google_genai/transformation.py b/litellm/llms/vertex_ai/google_genai/transformation.py index 02825026e1b..47933811196 100644 --- a/litellm/llms/vertex_ai/google_genai/transformation.py +++ b/litellm/llms/vertex_ai/google_genai/transformation.py @@ -1,16 +1,39 @@ """ Transformation for Calling Google models in their native format. """ -from typing import Literal +from typing import Literal, Optional, Union from litellm.llms.gemini.google_genai.transformation import GoogleGenAIConfig +from litellm.types.router import GenericLiteLLMParams class VertexAIGoogleGenAIConfig(GoogleGenAIConfig): """ Configuration for calling Google models in their native format. """ + HEADER_NAME = "Authorization" + BEARER_PREFIX = "Bearer" + @property def custom_llm_provider(self) -> Literal["gemini", "vertex_ai"]: return "vertex_ai" + + + def validate_environment( + self, + api_key: Optional[str], + headers: Optional[dict], + model: str, + litellm_params: Optional[Union[GenericLiteLLMParams, dict]] + ) -> dict: + default_headers = { + "Content-Type": "application/json", + } + + if api_key is not None: + default_headers[self.HEADER_NAME] = f"{self.BEARER_PREFIX} {api_key}" + if headers is not None: + default_headers.update(headers) + + return default_headers \ No newline at end of file diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index dd61d402f4f..68d7c6786f8 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,9 +1,7 @@ model_list: - - model_name: openai/* + - model_name: vertex_ai/* litellm_params: - model: openai/* - + model: vertex_ai/* litellm_settings: - success_callback: ["mlflow"] - failure_callback: ["mlflow"] \ No newline at end of file + callbacks: ["datadog_llm_observability"] diff --git a/litellm/types/google_genai/main.py b/litellm/types/google_genai/main.py index 33da5a06f49..96abc5d2ae8 100644 --- a/litellm/types/google_genai/main.py +++ b/litellm/types/google_genai/main.py @@ -15,9 +15,11 @@ GenerateContentContentListUnionDict = _genai_types.ContentListUnionDict GenerateContentConfigDict = _genai_types.GenerateContentConfigDict GenerateContentRequestParametersDict = _genai_types._GenerateContentParametersDict +ToolConfigDict = _genai_types.ToolConfigDict class GenerateContentRequestDict(GenerateContentRequestParametersDict): # type: ignore[misc] generationConfig: Optional[Any] + tools: Optional[ToolConfigDict] class GenerateContentResponse(GoogleGenAIGenerateContentResponse, BaseLiteLLMOpenAIResponseObject): # type: ignore[misc] diff --git a/tests/unified_google_tests/base_google_test.py b/tests/unified_google_tests/base_google_test.py index df6a886867f..28c70b1cea7 100644 --- a/tests/unified_google_tests/base_google_test.py +++ b/tests/unified_google_tests/base_google_test.py @@ -22,10 +22,11 @@ from litellm.types.utils import StandardLoggingPayload -@pytest.fixture(scope="session") -def load_vertex_ai_credentials(): - """Fixture to load Vertex AI credentials for all tests""" +def load_vertex_ai_credentials(model: str): + """Load Vertex AI credentials for tests""" # Define the path to the vertex_key.json file + if "vertex_ai" not in model: + return None print("loading vertex ai credentials") filepath = os.path.dirname(os.path.abspath(__file__)) vertex_key_path = filepath + "/vertex_key.json" @@ -63,14 +64,7 @@ def load_vertex_ai_credentials(): # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) - # Yield the path for tests that might need it - yield os.path.abspath(temp_file.name) - - # Cleanup: remove the temporary file after all tests complete - try: - os.unlink(temp_file.name) - except OSError: - pass # File might already be deleted + return os.path.abspath(temp_file.name) class TestCustomLogger(CustomLogger): @@ -94,6 +88,22 @@ def model_config(self) -> Dict[str, Any]: """Override in subclasses to provide model-specific configuration""" raise NotImplementedError("Subclasses must implement model_config") + @property + def _temp_files_to_cleanup(self): + """Lazy initialization of temp files list""" + if not hasattr(self, '_temp_files_list'): + self._temp_files_list = [] + return self._temp_files_list + + def cleanup_temp_files(self): + """Clean up any temporary files created during testing""" + for temp_file in self._temp_files_to_cleanup: + try: + os.unlink(temp_file) + except OSError: + pass # File might already be deleted + self._temp_files_to_cleanup.clear() + def _validate_non_streaming_response(self, response: Any): """Validate non-streaming response structure""" @@ -156,6 +166,10 @@ async def test_non_streaming_base(self, is_async: bool): ], role="user", ) + temp_file_path = load_vertex_ai_credentials(model=request_params["model"]) + if temp_file_path: + self._temp_files_to_cleanup.append(temp_file_path) + litellm._turn_on_debug() print(f"Testing {'async' if is_async else 'sync'} non-streaming with model config: {request_params}") @@ -184,6 +198,9 @@ async def test_non_streaming_base(self, is_async: bool): async def test_streaming_base(self, is_async: bool): """Base test for streaming requests (parametrized for sync/async)""" request_params = self.model_config + temp_file_path = load_vertex_ai_credentials(model=request_params["model"]) + if temp_file_path: + self._temp_files_to_cleanup.append(temp_file_path) contents = ContentDict( parts=[ PartDict( @@ -231,6 +248,9 @@ async def test_async_non_streaming_with_logging(self): litellm.callbacks = [test_custom_logger] request_params = self.model_config + temp_file_path = load_vertex_ai_credentials(model=request_params["model"]) + if temp_file_path: + self._temp_files_to_cleanup.append(temp_file_path) contents = ContentDict( parts=[ PartDict( @@ -272,6 +292,9 @@ async def test_async_streaming_with_logging(self): litellm.callbacks = [test_custom_logger] request_params = self.model_config + temp_file_path = load_vertex_ai_credentials(model=request_params["model"]) + if temp_file_path: + self._temp_files_to_cleanup.append(temp_file_path) contents = ContentDict( parts=[ PartDict( diff --git a/tests/unified_google_tests/test_google_ai_studio.py b/tests/unified_google_tests/test_google_ai_studio.py index d662c189cfe..e6859647b69 100644 --- a/tests/unified_google_tests/test_google_ai_studio.py +++ b/tests/unified_google_tests/test_google_ai_studio.py @@ -1,4 +1,13 @@ from base_google_test import BaseGoogleGenAITest +import sys +import os +sys.path.insert( + 0, os.path.abspath("../../..") +) # Adds the parent directory to the system path +import pytest +import litellm +import unittest.mock +import json class TestGoogleGenAIStudio(BaseGoogleGenAITest): """Test Google GenAI Studio""" @@ -7,4 +16,388 @@ class TestGoogleGenAIStudio(BaseGoogleGenAITest): def model_config(self): return { "model": "gemini/gemini-1.5-flash", - } \ No newline at end of file + } + +@pytest.mark.asyncio +async def test_mock_stream_generate_content_with_tools(): + """Test streaming function call response parsing and validation""" + from litellm.types.google_genai.main import ToolConfigDict + litellm._turn_on_debug() + contents = [ + { + "role": "user", + "parts": [ + {"text": "Schedule a meeting with Bob and Alice for 03/27/2025 at 10:00 AM about the Q3 planning"} + ] + } + ] + + # Mock streaming response chunks that represent a function call response + mock_response_chunk = { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "schedule_meeting", + "args": { + "attendees": ["Bob", "Alice"], + "date": "2025-03-27", + "time": "10:00", + "topic": "Q3 planning" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 15, + "candidatesTokenCount": 5, + "totalTokenCount": 20 + } + } + + # Convert to bytes as expected by the streaming iterator + raw_chunks = [ + f"data: {json.dumps(mock_response_chunk)}\n\n".encode(), + b"data: [DONE]\n\n" + ] + + # Mock the HTTP handler + with unittest.mock.patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=unittest.mock.AsyncMock) as mock_post: + # Create mock response object + mock_response = unittest.mock.MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + + # Mock the aiter_bytes method to return our chunks as bytes + async def mock_aiter_bytes(): + for chunk in raw_chunks: + yield chunk + + mock_response.aiter_bytes = mock_aiter_bytes + mock_post.return_value = mock_response + + print("\n--- Testing async agenerate_content_stream with function call parsing ---") + response = await litellm.google_genai.agenerate_content_stream( + model="gemini/gemini-1.5-flash", + contents=contents, + tools=[ + { + "functionDeclarations": [ + { + "name": "schedule_meeting", + "description": "Schedules a meeting with specified attendees at a given time and date.", + "parameters": { + "type": "object", + "properties": { + "attendees": { + "type": "array", + "items": {"type": "string"}, + "description": "List of people attending the meeting." + }, + "date": { + "type": "string", + "description": "Date of the meeting (e.g., '2024-07-29')" + }, + "time": { + "type": "string", + "description": "Time of the meeting (e.g., '15:00')" + }, + "topic": { + "type": "string", + "description": "The subject or topic of the meeting." + } + }, + "required": ["attendees", "date", "time", "topic"] + } + } + ] + } + ] + ) + + # Collect all chunks and parse function calls + chunks = [] + function_calls = [] + + chunk_count = 0 + async for chunk in response: + chunk_count += 1 + print(f"Received chunk {chunk_count}: {chunk}") + chunks.append(chunk) + + # Stop after a reasonable number of chunks to prevent infinite loop + if chunk_count > 10: + break + + # Parse function calls from byte chunks + if isinstance(chunk, bytes): + try: + # Decode bytes to string + chunk_str = chunk.decode('utf-8') + print(f"Decoded chunk: {chunk_str}") + + # Extract JSON from Server-Sent Events format (data: {...}) + if chunk_str.startswith('data: ') and not chunk_str.startswith('data: [DONE]'): + json_str = chunk_str[6:].strip() # Remove 'data: ' prefix + try: + parsed_json = json.loads(json_str) + print(f"Parsed JSON: {parsed_json}") + + # Parse function calls from the JSON + if "candidates" in parsed_json: + for candidate in parsed_json["candidates"]: + if "content" in candidate and "parts" in candidate["content"]: + for part in candidate["content"]["parts"]: + if "functionCall" in part: + function_calls.append({ + 'name': part["functionCall"]["name"], + 'args': part["functionCall"]["args"] + }) + print(f"Found function call: {part['functionCall']}") + except json.JSONDecodeError as e: + print(f"Failed to parse JSON: {e}") + except UnicodeDecodeError as e: + print(f"Failed to decode bytes: {e}") + + # Handle dict responses (in case some chunks are already parsed) + elif isinstance(chunk, dict): + # Direct dict response + if "candidates" in chunk: + for candidate in chunk["candidates"]: + if "content" in candidate and "parts" in candidate["content"]: + for part in candidate["content"]["parts"]: + if "functionCall" in part: + function_calls.append({ + 'name': part["functionCall"]["name"], + 'args': part["functionCall"]["args"] + }) + + # Handle object responses with attributes + elif hasattr(chunk, 'candidates') and chunk.candidates: + for candidate in chunk.candidates: + if hasattr(candidate, 'content') and candidate.content: + if hasattr(candidate.content, 'parts') and candidate.content.parts: + for part in candidate.content.parts: + if hasattr(part, 'function_call') and part.function_call: + function_calls.append({ + 'name': part.function_call.name, + 'args': part.function_call.args + }) + + # Assertions + print(f"\nFunction calls found: {function_calls}") + print(f"Total chunks received: {chunk_count}") + + # Assert we found at least one function call + assert len(function_calls) > 0, "Expected at least one function call in the streaming response" + + # Check the first function call + function_call = function_calls[0] + + # Assert function name + assert function_call['name'] == "schedule_meeting", f"Expected function name 'schedule_meeting', got '{function_call['name']}'" + + # Assert function arguments + args = function_call['args'] + assert "attendees" in args, "Expected 'attendees' in function call arguments" + assert "date" in args, "Expected 'date' in function call arguments" + assert "time" in args, "Expected 'time' in function call arguments" + assert "topic" in args, "Expected 'topic' in function call arguments" + + # Assert specific argument values + assert args["attendees"] == ["Bob", "Alice"], f"Expected attendees ['Bob', 'Alice'], got {args['attendees']}" + assert args["date"] == "2025-03-27", f"Expected date '2025-03-27', got {args['date']}" + assert args["time"] == "10:00", f"Expected time '10:00', got {args['time']}" + assert args["topic"] == "Q3 planning", f"Expected topic 'Q3 planning', got {args['topic']}" + + print("✅ All function call assertions passed!") + +@pytest.mark.asyncio +async def test_validate_post_request_parameters(): + """ + Test that the correct parameters are sent in the POST request to Google GenAI API + + Params validated + 1. model + 2. contents + 3. tools + """ + from litellm.types.google_genai.main import ToolConfigDict + + contents = [ + { + "role": "user", + "parts": [ + {"text": "Schedule a meeting with Bob and Alice for 03/27/2025 at 10:00 AM about the Q3 planning"} + ] + } + ] + + tools = [ + { + "functionDeclarations": [ + { + "name": "schedule_meeting", + "description": "Schedules a meeting with specified attendees at a given time and date.", + "parameters": { + "type": "object", + "properties": { + "attendees": { + "type": "array", + "items": {"type": "string"}, + "description": "List of people attending the meeting." + }, + "date": { + "type": "string", + "description": "Date of the meeting (e.g., '2024-07-29')" + }, + "time": { + "type": "string", + "description": "Time of the meeting (e.g., '15:00')" + }, + "topic": { + "type": "string", + "description": "The subject or topic of the meeting." + } + }, + "required": ["attendees", "date", "time", "topic"] + } + } + ] + } + ] + + # Mock response for the HTTP request + raw_chunks = [ + b"data: [DONE]\n\n" + ] + + # Mock the HTTP handler to capture the request + with unittest.mock.patch("litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", new_callable=unittest.mock.AsyncMock) as mock_post: + # Create mock response object + mock_response = unittest.mock.MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + + # Mock the aiter_bytes method + async def mock_aiter_bytes(): + for chunk in raw_chunks: + yield chunk + + mock_response.aiter_bytes = mock_aiter_bytes + mock_post.return_value = mock_response + + print("\n--- Testing POST request parameters validation ---") + + # Make the API call + response = await litellm.google_genai.agenerate_content_stream( + model="gemini/gemini-1.5-flash", + contents=contents, + tools=tools + ) + + # Consume the response to ensure the request is made + async for chunk in response: + pass + + # Validate that the HTTP post was called + assert mock_post.called, "Expected HTTP POST to be called" + + # Get the call arguments + call_args, call_kwargs = mock_post.call_args + + print(f"POST call args: {call_args}") + print(f"POST call kwargs: {call_kwargs}") + + # Validate URL contains the correct endpoint + if call_args: + url = call_args[0] if len(call_args) > 0 else call_kwargs.get('url') + assert url is not None, "Expected URL to be provided" + assert "generativelanguage.googleapis.com" in url, f"Expected Google API URL, got: {url}" + assert "streamGenerateContent" in url, f"Expected streamGenerateContent endpoint, got: {url}" + print(f"✅ URL validation passed: {url}") + + # Get the request data/json from the call + request_data = None + if 'data' in call_kwargs: + # If data is passed as bytes, decode it + if isinstance(call_kwargs['data'], bytes): + request_data = json.loads(call_kwargs['data'].decode('utf-8')) + else: + request_data = call_kwargs['data'] + elif 'json' in call_kwargs: + request_data = call_kwargs['json'] + + assert request_data is not None, "Expected request data to be provided" + print(f"Request data: {json.dumps(request_data, indent=2)}") + + # Validate model field + assert "model" in request_data, "Expected 'model' field in request data" + # Model might be transformed, but should contain gemini-1.5-flash + model_value = request_data["model"] + assert "gemini-1.5-flash" in model_value, f"Expected model to contain 'gemini-1.5-flash', got: {model_value}" + print(f"✅ Model validation passed: {model_value}") + + # Validate contents field + assert "contents" in request_data, "Expected 'contents' field in request data" + request_contents = request_data["contents"] + assert isinstance(request_contents, list), "Expected contents to be a list" + assert len(request_contents) > 0, "Expected at least one content item" + + # Check the first content item + first_content = request_contents[0] + assert "role" in first_content, "Expected 'role' in content item" + assert first_content["role"] == "user", f"Expected role 'user', got: {first_content['role']}" + assert "parts" in first_content, "Expected 'parts' in content item" + assert isinstance(first_content["parts"], list), "Expected parts to be a list" + assert len(first_content["parts"]) > 0, "Expected at least one part" + + # Check the text content + first_part = first_content["parts"][0] + assert "text" in first_part, "Expected 'text' in part" + expected_text = "Schedule a meeting with Bob and Alice for 03/27/2025 at 10:00 AM about the Q3 planning" + assert first_part["text"] == expected_text, f"Expected text '{expected_text}', got: {first_part['text']}" + print(f"✅ Contents validation passed") + + # Validate tools field + assert "tools" in request_data, "Expected 'tools' field in request data" + request_tools = request_data["tools"] + assert isinstance(request_tools, list), "Expected tools to be a list" + assert len(request_tools) > 0, "Expected at least one tool" + + # Check the first tool + first_tool = request_tools[0] + assert "functionDeclarations" in first_tool, "Expected 'functionDeclarations' in tool" + function_declarations = first_tool["functionDeclarations"] + assert isinstance(function_declarations, list), "Expected functionDeclarations to be a list" + assert len(function_declarations) > 0, "Expected at least one function declaration" + + # Check the function declaration + func_decl = function_declarations[0] + assert "name" in func_decl, "Expected 'name' in function declaration" + assert func_decl["name"] == "schedule_meeting", f"Expected function name 'schedule_meeting', got: {func_decl['name']}" + assert "description" in func_decl, "Expected 'description' in function declaration" + assert "parameters" in func_decl, "Expected 'parameters' in function declaration" + + # Check function parameters + params = func_decl["parameters"] + assert "type" in params, "Expected 'type' in parameters" + assert params["type"] == "object", f"Expected parameters type 'object', got: {params['type']}" + assert "properties" in params, "Expected 'properties' in parameters" + assert "required" in params, "Expected 'required' in parameters" + + # Check required fields + required_fields = params["required"] + expected_required = ["attendees", "date", "time", "topic"] + assert set(required_fields) == set(expected_required), f"Expected required fields {expected_required}, got: {required_fields}" + print(f"✅ Tools validation passed") + + print("✅ All POST request parameter validations passed!") \ No newline at end of file diff --git a/tests/unified_google_tests/test_vertex_anthropic.py b/tests/unified_google_tests/test_vertex_anthropic.py index b71b2e2e8ce..1e34a41a55b 100644 --- a/tests/unified_google_tests/test_vertex_anthropic.py +++ b/tests/unified_google_tests/test_vertex_anthropic.py @@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Union import pytest from unittest.mock import MagicMock, AsyncMock, patch +import httpx sys.path.insert( 0, os.path.abspath("../../..") @@ -20,7 +21,7 @@ from litellm.types.utils import StandardLoggingPayload -def vertex_anthropic_mock_response(*args, **kwargs): +async def vertex_anthropic_mock_response(*args, **kwargs): """Mock response for vertex AI anthropic call""" mock_response = MagicMock() mock_response.status_code = 200 @@ -46,7 +47,6 @@ def vertex_anthropic_mock_response(*args, **kwargs): @pytest.mark.asyncio async def test_vertex_anthropic_mocked(): """Test agenerate_content with mocked HTTP calls to validate URL and request body""" - from litellm.llms.custom_httpx.llm_http_handler import AsyncHTTPHandler # Set up test data contents = ContentDict( @@ -58,32 +58,28 @@ async def test_vertex_anthropic_mocked(): role="user", ) - # Create HTTP client and mock response - client = AsyncHTTPHandler() - httpx_response = AsyncMock() - httpx_response.side_effect = vertex_anthropic_mock_response - # Expected values for validation expected_url = "https://us-east5-aiplatform.googleapis.com/v1/projects/internal-litellm-local-dev/locations/us-east5/publishers/anthropic/models/claude-sonnet-4:rawPredict" expected_body_keys = {"messages", "anthropic_version", "max_tokens"} expected_message_content = "Hello, can you tell me a short joke?" - # Patch the HTTP client and make the call - with patch.object(client, "post", new=httpx_response) as mock_call: + # Patch the AsyncHTTPHandler.post method at the module level + with patch('litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post', new_callable=AsyncMock) as mock_post: + mock_post.return_value = await vertex_anthropic_mock_response() + response = await agenerate_content( contents=contents, model="vertex_ai/claude-sonnet-4", vertex_location="us-east5", vertex_project="internal-litellm-local-dev", custom_llm_provider="vertex_ai", - client=client, ) # Verify the call was made - assert mock_call.call_count == 1 + assert mock_post.call_count == 1 # Get the call arguments - call_args = mock_call.call_args + call_args = mock_post.call_args call_kwargs = call_args.kwargs if call_args else {} # Extract URL (could be in args[0] or kwargs['url']) @@ -145,12 +141,13 @@ async def test_vertex_anthropic_mocked(): print(f"Response: {response}") -def vertex_anthropic_streaming_mock_response(*args, **kwargs): - """Mock streaming response for vertex AI anthropic call""" +class MockAsyncStreamResponse: + """Mock async streaming response that mimics httpx streaming response""" - def create_streaming_response(): - """Generator that simulates streaming chunks""" - chunks = [ + def __init__(self): + self.status_code = 200 + self.headers = {"Content-Type": "text/event-stream"} + self._chunks = [ { "type": "message_start", "message": { @@ -192,23 +189,26 @@ def create_streaming_response(): "type": "message_stop" } ] - - for chunk in chunks: - # Convert to bytes as streaming responses typically return bytes + + async def aiter_bytes(self, chunk_size=1024): + """Async iterator for response bytes""" + for chunk in self._chunks: yield f"data: {json.dumps(chunk)}\n\n".encode() - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"Content-Type": "text/event-stream"} - mock_response.iter_bytes = lambda chunk_size=1024: create_streaming_response() - mock_response.aiter_bytes = lambda chunk_size=1024: create_streaming_response() - return mock_response + async def aiter_lines(self): + """Async iterator for response lines (required by anthropic handler)""" + for chunk in self._chunks: + yield f"data: {json.dumps(chunk)}\n\n" + + +async def vertex_anthropic_streaming_mock_response(*args, **kwargs): + """Mock streaming response for vertex AI anthropic call""" + return MockAsyncStreamResponse() @pytest.mark.asyncio async def test_vertex_anthropic_streaming_mocked(): """Test agenerate_content_stream with mocked HTTP calls to validate URL and request body""" - from litellm.llms.custom_httpx.llm_http_handler import AsyncHTTPHandler # Set up test data contents = ContentDict( @@ -220,32 +220,28 @@ async def test_vertex_anthropic_streaming_mocked(): role="user", ) - # Create HTTP client and mock response - client = AsyncHTTPHandler() - httpx_response = AsyncMock() - httpx_response.side_effect = vertex_anthropic_streaming_mock_response - # Expected values for validation (same as non-streaming) expected_url = "https://us-east5-aiplatform.googleapis.com/v1/projects/internal-litellm-local-dev/locations/us-east5/publishers/anthropic/models/claude-sonnet-4:streamRawPredict" expected_body_keys = {"messages", "anthropic_version", "max_tokens"} expected_message_content = "Hello, can you tell me a short joke?" - # Patch the HTTP client and make the call - with patch.object(client, "post", new=httpx_response) as mock_call: + # Patch the AsyncHTTPHandler.post method at the module level + with patch('litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post', new_callable=AsyncMock) as mock_post: + mock_post.return_value = await vertex_anthropic_streaming_mock_response() + response_stream = await agenerate_content_stream( contents=contents, model="vertex_ai/claude-sonnet-4", vertex_location="us-east5", vertex_project="internal-litellm-local-dev", custom_llm_provider="vertex_ai", - client=client, ) # Verify the call was made - assert mock_call.call_count == 1 + assert mock_post.call_count == 1 # Get the call arguments - call_args = mock_call.call_args + call_args = mock_post.call_args call_kwargs = call_args.kwargs if call_args else {} # Extract URL (could be in args[0] or kwargs['url'])