@@ -215,6 +215,15 @@ def __init__(
@classmethod
def get_config(cls):
return super().get_config()

def _supports_penalty_parameters(self, model: str) -> bool:
unsupported_models = ["gemini-2.5-pro-preview-06-05"]

for pattern in unsupported_models:
if pattern in model:
return False

return True

def get_supported_openai_params(self, model: str) -> List[str]:
supported_params = [
@@ -229,8 +238,6 @@ def get_supported_openai_params(self, model: str) -> List[str]:
"response_format",
"n",
"stop",
"frequency_penalty",
"presence_penalty",
"extra_headers",
"seed",
"logprobs",
@@ -239,6 +246,11 @@ def get_supported_openai_params(self, model: str) -> List[str]:
"parallel_tool_calls",
"web_search_options",
]

# Add penalty parameters only for models that support them
if self._supports_penalty_parameters(model):
supported_params.extend(["frequency_penalty", "presence_penalty"])

if supports_reasoning(model):
supported_params.append("reasoning_effort")
supported_params.append("thinking")
@@ -680,8 +692,12 @@ def map_openai_params( # noqa: PLR0915
)
elif param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if self._supports_penalty_parameters(model):
optional_params["frequency_penalty"] = value
elif param == "presence_penalty":
optional_params["presence_penalty"] = value
if self._supports_penalty_parameters(model):
optional_params["presence_penalty"] = value
elif param == "logprobs":
optional_params["responseLogprobs"] = value
elif param == "top_logprobs":
@@ -1151,3 +1151,57 @@ def test_vertex_ai_map_google_maps_tool_with_location():
assert retrieval_config["latLng"]["latitude"] == 37.7749
assert retrieval_config["latLng"]["longitude"] == -122.4194
assert retrieval_config["languageCode"] == "en_US"

def test_vertex_ai_penalty_parameters_validation():
"""
Test that penalty parameters are properly validated for different Gemini models.

This test ensures that:
1. Models that don't support penalty parameters (like preview models) filter them out
2. Models that support penalty parameters keep them in the supported parameter list
"""
v = VertexGeminiConfig()

# Test cases: (model_name, should_support_penalty_params)
test_cases = [
("gemini-2.5-pro-preview-06-05", False), # Preview model - should not support
("gemini-1.5-pro", True), # Non-preview model - should support
]

for model, should_support in test_cases:
# Test _supports_penalty_parameters method
assert v._supports_penalty_parameters(model) == should_support, \
f"Model {model} penalty support should be {should_support}"

# Test get_supported_openai_params method
supported_params = v.get_supported_openai_params(model)
has_penalty_params = "frequency_penalty" in supported_params and "presence_penalty" in supported_params
assert has_penalty_params == should_support, \
f"Model {model} should {'include' if should_support else 'exclude'} penalty params in supported list"

# Test parameter mapping for unsupported model
model = "gemini-2.5-pro-preview-06-05"
non_default_params = {
"temperature": 0.7,
"frequency_penalty": 0.5,
"presence_penalty": 0.3,
"max_tokens": 100
}

optional_params = {}
result = v.map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
drop_params=False
)

# Penalty parameters should be filtered out for unsupported models
assert "frequency_penalty" not in result, "frequency_penalty should be filtered out for unsupported model"
assert "presence_penalty" not in result, "presence_penalty should be filtered out for unsupported model"

# Other parameters should still be included
assert "temperature" in result, "temperature should still be included"
assert "max_output_tokens" in result, "max_output_tokens should still be included"
assert result["temperature"] == 0.7
assert result["max_output_tokens"] == 100