diff --git a/google/genai/_base_url.py b/google/genai/_base_url.py new file mode 100644 index 000000000..563c22709 --- /dev/null +++ b/google/genai/_base_url.py @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from typing import Optional + +from .types import HttpOptions + +_default_base_gemini_url = None +_default_base_vertex_url = None + + +class BaseUrlParameters: + """Parameters for setting the base URLs for the Gemini API and Vertex AI API.""" + + gemini_url: str | None = None + vertex_url: str | None = None + + def __init__( + self, + gemini_url: str | None = None, + vertex_url: str | None = None, + ): + self.gemini_url = gemini_url + self.vertex_url = vertex_url + + +def set_default_base_urls(base_url_params: BaseUrlParameters) -> None: + """Overrides the base URLs for the Gemini API and Vertex AI API.""" + global _default_base_gemini_url, _default_base_vertex_url + _default_base_gemini_url = base_url_params.gemini_url + _default_base_vertex_url = base_url_params.vertex_url + + +def get_default_base_urls() -> BaseUrlParameters: + """Overrides the base URLs for the Gemini API and Vertex AI API.""" + return BaseUrlParameters( + gemini_url=_default_base_gemini_url, vertex_url=_default_base_vertex_url + ) + + +def get_base_url( + vertexai: bool, + http_options: Optional[HttpOptions] = None, +) -> str | None: + """Returns the default base URL based on the following priority. + + 1. Base URLs set via HttpOptions. + 2. Base URLs set via the latest call to setDefaultBaseUrls. + 3. Base URLs set via environment variables. + """ + if http_options and http_options.base_url: + return http_options.base_url + + if vertexai: + return _default_base_vertex_url or os.getenv('GOOGLE_VERTEX_BASE_URL') + else: + return _default_base_gemini_url or os.getenv('GOOGLE_GEMINI_BASE_URL') diff --git a/google/genai/client.py b/google/genai/client.py index 3663f0245..885429a6c 100644 --- a/google/genai/client.py +++ b/google/genai/client.py @@ -20,6 +20,7 @@ import pydantic from ._api_client import BaseApiClient +from ._base_url import get_base_url from ._replay_api_client import ReplayApiClient from .batches import AsyncBatches, Batches from .caches import AsyncCaches, Caches @@ -78,6 +79,7 @@ def live(self) -> AsyncLive: def operations(self) -> AsyncOperations: return self._operations + class DebugConfig(pydantic.BaseModel): """Configuration options that change client network behavior when testing.""" @@ -114,26 +116,28 @@ class Client: Attributes: api_key: The `API key `_ to use for authentication. Applies to the Gemini Developer API only. - vertexai: Indicates whether the client should use the Vertex AI - API endpoints. Defaults to False (uses Gemini Developer API endpoints). + vertexai: Indicates whether the client should use the Vertex AI API + endpoints. Defaults to False (uses Gemini Developer API endpoints). Applies to the Vertex AI API only. credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be obtained from environment variables and - default credentials. For more information, see - `Set up Application Default Credentials + default credentials. For more information, see `Set up Application Default + Credentials `_. Applies to the Vertex AI API only. - project: The `Google Cloud project ID `_ to - use for quota. Can be obtained from environment variables (for example, + project: The `Google Cloud project ID + `_ to use + for quota. Can be obtained from environment variables (for example, ``GOOGLE_CLOUD_PROJECT``). Applies to the Vertex AI API only. - location: The `location `_ + location: The `location + `_ to send API requests to (for example, ``us-central1``). Can be obtained from environment variables. Applies to the Vertex AI API only. debug_config: Config settings that control network behavior of the client. This is typically used when running test code. http_options: Http options to use for the client. These options will be - applied to all requests made by the client. Example usage: - `client = genai.Client(http_options=types.HttpOptions(api_version='v1'))`. + applied to all requests made by the client. Example usage: `client = + genai.Client(http_options=types.HttpOptions(api_version='v1'))`. Usage for the Gemini Developer API: @@ -198,6 +202,13 @@ def __init__( if isinstance(http_options, dict): http_options = HttpOptions(**http_options) + base_url = get_base_url(vertexai, http_options) + if base_url: + if http_options: + http_options.base_url = base_url + else: + http_options = HttpOptions(base_url=base_url) + self._api_client = self._get_api_client( vertexai=vertexai, api_key=api_key, diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index 5bac08f1a..97e847ba3 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -16,15 +16,17 @@ """Tests for client initialization.""" +import logging +import os +import ssl + import certifi import google.auth from google.auth import credentials -import logging -import os import pytest -import ssl from ... import _api_client as api_client +from ... import _base_url as base_url from ... import _replay_api_client as replay_api_client from ... import Client @@ -274,6 +276,7 @@ def test_invalid_vertexai_constructor_empty(monkeypatch): monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "") monkeypatch.setenv("GOOGLE_CLOUD_LOCATION", "") monkeypatch.setenv("GOOGLE_API_KEY", "") + def mock_auth_default(scopes=None): return None, None @@ -319,10 +322,7 @@ def test_invalid_vertexai_constructor3(monkeypatch): m.delenv("GOOGLE_CLOUD_LOCATION", raising=False) project_id = "fake_project_id" with pytest.raises(ValueError): - Client( - vertexai=True, - project=project_id - ) + Client(vertexai=True, project=project_id) def test_vertexai_explicit_arg_precedence1(monkeypatch): @@ -578,7 +578,7 @@ def test_vertexai_global_endpoint(monkeypatch): def test_client_logs_to_logger_instance(monkeypatch, caplog): - caplog.set_level(logging.DEBUG, logger='google_genai._api_client') + caplog.set_level(logging.DEBUG, logger="google_genai._api_client") project_id = "fake_project_id" location = "fake-location" @@ -588,75 +588,180 @@ def test_client_logs_to_logger_instance(monkeypatch, caplog): _ = Client(vertexai=True, api_key=api_key) - assert 'INFO' in caplog.text - assert 'The user provided Vertex AI API key will take precedence' in caplog.text + assert "INFO" in caplog.text + assert ( + "The user provided Vertex AI API key will take precedence" in caplog.text + ) + def test_client_ssl_context_implicit_initialization(): client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( - api_client.HttpOptions()) + api_client.HttpOptions() + ) assert client_args["verify"] assert async_client_args["verify"] assert isinstance(client_args["verify"], ssl.SSLContext) assert isinstance(async_client_args["verify"], ssl.SSLContext) + def test_client_ssl_context_explicit_initialization_same_args(): ctx = ssl.create_default_context( - cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), - capath=os.environ.get('SSL_CERT_DIR'), + cafile=os.environ.get("SSL_CERT_FILE", certifi.where()), + capath=os.environ.get("SSL_CERT_DIR"), ) options = api_client.HttpOptions( - client_args={"verify": ctx}, async_client_args={"verify": ctx}) + client_args={"verify": ctx}, async_client_args={"verify": ctx} + ) client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( - options) + options + ) assert client_args["verify"] == ctx assert async_client_args["verify"] == ctx + def test_client_ssl_context_explicit_initialization_separate_args(): ctx = ssl.create_default_context( - cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), - capath=os.environ.get('SSL_CERT_DIR'), + cafile=os.environ.get("SSL_CERT_FILE", certifi.where()), + capath=os.environ.get("SSL_CERT_DIR"), ) async_ctx = ssl.create_default_context( - cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), - capath=os.environ.get('SSL_CERT_DIR'), + cafile=os.environ.get("SSL_CERT_FILE", certifi.where()), + capath=os.environ.get("SSL_CERT_DIR"), ) options = api_client.HttpOptions( - client_args={"verify": ctx}, async_client_args={"verify": async_ctx}) + client_args={"verify": ctx}, async_client_args={"verify": async_ctx} + ) client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( - options) + options + ) assert client_args["verify"] == ctx assert async_client_args["verify"] == async_ctx + def test_client_ssl_context_explicit_initialization_sync_args(): ctx = ssl.create_default_context( - cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), - capath=os.environ.get('SSL_CERT_DIR'), + cafile=os.environ.get("SSL_CERT_FILE", certifi.where()), + capath=os.environ.get("SSL_CERT_DIR"), ) - options = api_client.HttpOptions( - client_args={"verify": ctx}) + options = api_client.HttpOptions(client_args={"verify": ctx}) client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( - options) + options + ) assert client_args["verify"] == ctx assert async_client_args["verify"] == ctx + def test_client_ssl_context_explicit_initialization_async_args(): ctx = ssl.create_default_context( - cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), - capath=os.environ.get('SSL_CERT_DIR'), + cafile=os.environ.get("SSL_CERT_FILE", certifi.where()), + capath=os.environ.get("SSL_CERT_DIR"), ) - options = api_client.HttpOptions( - async_client_args={"verify": ctx}) + options = api_client.HttpOptions(async_client_args={"verify": ctx}) client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( - options) + options + ) assert client_args["verify"] == ctx assert async_client_args["verify"] == ctx + + +def test_constructor_with_base_url_from_http_options(): + mldev_http_options = { + "base_url": "https://placeholder-fake-url.com/", + } + vertexai_http_options = { + "base_url": ( + "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/" + ), + } + + mldev_client = Client( + api_key="google_api_key", http_options=mldev_http_options + ) + assert not mldev_client.models._api_client.vertexai + assert ( + mldev_client.models._api_client.get_read_only_http_options()["base_url"] + == "https://placeholder-fake-url.com/" + ) + + vertexai_client = Client( + vertexai=True, + project="fake_project_id", + location="fake-location", + http_options=vertexai_http_options, + ) + assert vertexai_client.models._api_client.vertexai + assert ( + vertexai_client.models._api_client.get_read_only_http_options()[ + "base_url" + ] + == "https://{self.location}-aiplatform.googleapis.com/{{api_version}}/" + ) + + +def test_constructor_with_base_url_from_set_default_base_urls(): + base_url.set_default_base_urls( + base_url.BaseUrlParameters( + gemini_url="https://gemini-base-url.com/", + vertex_url="https://vertex-base-url.com/", + ) + ) + mldev_client = Client(api_key="google_api_key") + assert not mldev_client.models._api_client.vertexai + assert ( + mldev_client.models._api_client.get_read_only_http_options()["base_url"] + == "https://gemini-base-url.com/" + ) + + vertexai_client = Client( + vertexai=True, + project="fake_project_id", + location="fake-location", + ) + assert vertexai_client.models._api_client.vertexai + assert ( + vertexai_client.models._api_client.get_read_only_http_options()[ + "base_url" + ] + == "https://vertex-base-url.com/" + ) + base_url.set_default_base_urls( + base_url.BaseUrlParameters( + gemini_url=None, + vertex_url=None, + ) + ) + + +def test_constructor_with_base_url_from_environment_variables(monkeypatch): + monkeypatch.setenv("GOOGLE_GEMINI_BASE_URL", "https://gemini-base-url.com/") + monkeypatch.setenv("GOOGLE_VERTEX_BASE_URL", "https://vertex-base-url.com/") + + mldev_client = Client(api_key="google_api_key") + assert not mldev_client.models._api_client.vertexai + assert ( + mldev_client.models._api_client.get_read_only_http_options()["base_url"] + == "https://gemini-base-url.com/" + ) + + vertexai_client = Client( + vertexai=True, + project="fake_project_id", + location="fake-location", + ) + assert vertexai_client.models._api_client.vertexai + assert ( + vertexai_client.models._api_client.get_read_only_http_options()[ + "base_url" + ] + == "https://vertex-base-url.com/" + )