From 512e7c7cb18984c5a6fe7c4716adfb409ecc3646 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 22 Apr 2024 12:17:38 -0700 Subject: [PATCH 01/11] Working on request_options --- google/generativeai/answer.py | 6 +- google/generativeai/discuss.py | 14 ++-- google/generativeai/embedding.py | 16 ++--- google/generativeai/generative_models.py | 10 +-- google/generativeai/models.py | 21 +++--- google/generativeai/retriever.py | 17 ++--- google/generativeai/text.py | 13 ++-- google/generativeai/types/retriever_types.py | 67 ++++++++++---------- 8 files changed, 82 insertions(+), 82 deletions(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index bbadc76c7..87828348e 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -26,12 +26,10 @@ get_default_generative_client, get_default_generative_async_client, ) -from google.generativeai import string_utils from google.generativeai.types import model_types -from google.generativeai import models +from google.generativeai.types import helper_types from google.generativeai.types import safety_types from google.generativeai.types import content_types -from google.generativeai.types import answer_types from google.generativeai.types import retriever_types from google.generativeai.types.retriever_types import MetadataFilter @@ -247,7 +245,7 @@ def generate_answer( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): f""" Calls the GenerateAnswer API and returns a `types.Answer` containing the response. diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 0cc342096..bb7abf3a6 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -26,10 +26,12 @@ from google.generativeai.client import get_default_discuss_async_client from google.generativeai import string_utils from google.generativeai.types import discuss_types +from google.generativeai.types import helper_types from google.generativeai.types import model_types from google.generativeai.types import safety_types + def _make_message(content: discuss_types.MessageOptions) -> glm.Message: """Creates a `glm.Message` object from the provided content.""" if isinstance(content, glm.Message): @@ -316,7 +318,7 @@ def chat( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, client: glm.DiscussServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: """Calls the API and returns a `types.ChatResponse` containing the response. @@ -416,7 +418,7 @@ async def chat_async( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: request = _make_generate_message_request( model=model, @@ -469,7 +471,7 @@ def last(self, message: discuss_types.MessageOptions): def reply( self, message: discuss_types.MessageOptions, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceAsyncClient): raise TypeError(f"reply can't be called on an async client, use reply_async instead.") @@ -537,7 +539,7 @@ def _build_chat_response( def _generate_response( request: glm.GenerateMessageRequest, client: glm.DiscussServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: request_options = {} @@ -553,7 +555,7 @@ def _generate_response( async def _generate_response_async( request: glm.GenerateMessageRequest, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: request_options = {} @@ -574,7 +576,7 @@ def count_message_tokens( messages: discuss_types.MessagesOptions | None = None, model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.TokenCount: model = model_types.make_model_name(model) prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 375d5dcb4..14fff1737 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -14,8 +14,6 @@ # limitations under the License. from __future__ import annotations -import dataclasses -from collections.abc import Iterable, Sequence, Mapping import itertools from typing import Any, Iterable, overload, TypeVar, Union, Mapping @@ -24,7 +22,7 @@ from google.generativeai.client import get_default_generative_client from google.generativeai.client import get_default_generative_async_client -from google.generativeai import string_utils +from google.generativeai.types import helper_types from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai.types import content_types @@ -104,7 +102,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -116,7 +114,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -127,7 +125,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create embeddings for content passed in. @@ -224,7 +222,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -236,7 +234,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -247,7 +245,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """The async version of `genai.embed_content`.""" model = model_types.make_model_name(model) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 02cab0b29..ef4f41e1e 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -15,9 +15,9 @@ import google.api_core.exceptions from google.ai import generativelanguage as glm from google.generativeai import client -from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.types import generation_types +from google.generativeai.types import helper_types from google.generativeai.types import safety_types @@ -181,7 +181,7 @@ def generate_content( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.GenerateContentResponse: """A multipurpose function to generate responses from the model. @@ -281,7 +281,7 @@ async def generate_content_async( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `GenerativeModel.generate_content`.""" request = self._prepare_request( @@ -323,7 +323,7 @@ async def generate_content_async( def count_tokens( self, contents: content_types.ContentsType, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> glm.CountTokensResponse: if request_options is None: request_options = {} @@ -339,7 +339,7 @@ def count_tokens( async def count_tokens_async( self, contents: content_types.ContentsType, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> glm.CountTokensResponse: if request_options is None: request_options = {} diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 7c7b8a5cf..00c51d24f 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -21,6 +21,7 @@ from google.generativeai import operations from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types +from google.generativeai.types import helper_types from google.api_core import operation from google.api_core import protobuf_helpers from google.protobuf import field_mask_pb2 @@ -31,7 +32,7 @@ def get_model( name: model_types.AnyModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.Model | model_types.TunedModel: """Given a model name, fetch the `types.Model` or `types.TunedModel` object. @@ -62,7 +63,7 @@ def get_base_model( name: model_types.BaseModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.Model: """Get the `types.Model` for the given base model name. @@ -99,7 +100,7 @@ def get_tuned_model( name: model_types.TunedModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.TunedModel: """Get the `types.TunedModel` for the given tuned model name. @@ -162,7 +163,7 @@ def list_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.ModelsIterable: """Lists available models. @@ -196,7 +197,7 @@ def list_tuned_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.TunedModelsIterable: """Lists available models. @@ -244,7 +245,7 @@ def create_tuned_model( input_key: str = "text_input", output_key: str = "output", client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> operations.CreateTunedModelOperation: """Launches a tuning job to create a TunedModel. @@ -357,7 +358,7 @@ def update_tuned_model( updates: None = None, *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.TunedModel: pass @@ -368,7 +369,7 @@ def update_tuned_model( updates: dict[str, Any], *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType| None = None, ) -> model_types.TunedModel: pass @@ -378,7 +379,7 @@ def update_tuned_model( updates: dict[str, Any] | None = None, *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: """Push updates to the tuned model. Only certain attributes are updatable.""" if request_options is None: @@ -436,7 +437,7 @@ def _apply_update(thing, path, value): def delete_tuned_model( tuned_model: model_types.TunedModelNameOptions, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> None: if request_options is None: request_options = {} diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index dfd5e9026..190a222a6 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -23,6 +23,7 @@ from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client +from google.generativeai.types import helper_types from google.generativeai.types.model_types import idecode_time from google.generativeai.types import retriever_types @@ -31,7 +32,7 @@ def create_corpus( name: str | None = None, display_name: str | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """ Create a new `Corpus` in the retriever service, and return it as a `retriever_types.Corpus` instance. @@ -78,7 +79,7 @@ async def create_corpus_async( name: str | None = None, display_name: str | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """This is the async version of `retriever.create_corpus`.""" if request_options is None: @@ -106,7 +107,7 @@ async def create_corpus_async( def get_corpus( name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """ Fetch a specific `Corpus` from the retriever service. @@ -139,7 +140,7 @@ def get_corpus( async def get_corpus_async( name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """This is the async version of `retriever.get_corpus`.""" if request_options is None: @@ -164,7 +165,7 @@ def delete_corpus( name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """ Delete a `Corpus` from the service. @@ -191,7 +192,7 @@ async def delete_corpus_async( name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """This is the async version of `retriever.delete_corpus`.""" if request_options is None: @@ -211,7 +212,7 @@ def list_corpora( *, page_size: Optional[int] = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[retriever_types.Corpus]: """ List the Corpuses you own in the service. @@ -242,7 +243,7 @@ async def list_corpora_async( *, page_size: Optional[int] = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[retriever_types.Corpus]: """This is the async version of `retriever.list_corpora`.""" if request_options is None: diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 3a147f945..2f3da6842 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -23,6 +23,7 @@ from google.generativeai.client import get_default_text_client from google.generativeai import string_utils +from google.generativeai.types import helper_types from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai import models @@ -141,7 +142,7 @@ def generate_text( safety_settings: safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, client: glm.TextServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.Completion: """Calls the API and returns a `types.Completion` containing the response. @@ -217,7 +218,7 @@ def __init__(self, **kwargs): def _generate_response( request: glm.GenerateTextRequest, client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ Generates a response using the provided `glm.GenerateTextRequest` and client. @@ -253,7 +254,7 @@ def count_text_tokens( model: model_types.AnyModelNameOptions, prompt: str, client: glm.TextServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.TokenCount: base_model = models.get_base_model_name(model) @@ -276,7 +277,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: str, client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -285,7 +286,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: Sequence[str], client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -293,7 +294,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: str | Sequence[str], client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create an embedding for the text passed in. diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index e1c29042b..bb5085b5d 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -29,8 +29,7 @@ from google.generativeai.client import get_dafault_permission_client from google.generativeai.client import get_dafault_permission_async_client from google.generativeai import string_utils -from google.generativeai.types import safety_types -from google.generativeai.types import citation_types +from google.generativeai.types import helper_types from google.generativeai.types import permission_types from google.generativeai.types.model_types import idecode_time from google.generativeai.utils import flatten_update_paths @@ -261,7 +260,7 @@ def create_document( display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ Request to create a `Document`. @@ -312,7 +311,7 @@ async def create_document_async( display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.create_document`.""" if request_options is None: @@ -346,7 +345,7 @@ def get_document( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ Get information about a specific `Document`. @@ -375,7 +374,7 @@ async def get_document_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.get_document`.""" if request_options is None: @@ -401,7 +400,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified `Corpus`. @@ -439,7 +438,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.update`.""" if request_options is None: @@ -470,7 +469,7 @@ def query( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """ Query a corpus for information. @@ -524,7 +523,7 @@ async def query_async( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """This is the async version of `Corpus.query`.""" if request_options is None: @@ -566,7 +565,7 @@ def delete_document( name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Delete a document in the corpus. @@ -593,7 +592,7 @@ async def delete_document_async( name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.delete_document`.""" if request_options is None: @@ -612,7 +611,7 @@ def list_documents( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Document]: """ List documents in corpus. @@ -642,7 +641,7 @@ async def list_documents_async( self, page_size: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Document]: """This is the async version of `Corpus.list_documents`.""" if request_options is None: @@ -812,7 +811,7 @@ def create_chunk( name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """ Create a `Chunk` object which has textual data. @@ -869,7 +868,7 @@ async def create_chunk_async( name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """This is the async version of `Document.create_chunk`.""" if request_options is None: @@ -968,7 +967,7 @@ def batch_create_chunks( self, chunks: BatchCreateChunkOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Create chunks within the given document. @@ -994,7 +993,7 @@ async def batch_create_chunks_async( self, chunks: BatchCreateChunkOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_create_chunk`.""" if request_options is None: @@ -1011,7 +1010,7 @@ def get_chunk( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Get information about a specific chunk. @@ -1040,7 +1039,7 @@ async def get_chunk_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.get_chunk`.""" if request_options is None: @@ -1060,7 +1059,7 @@ def list_chunks( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Chunk]: """ List chunks of a document. @@ -1086,7 +1085,7 @@ async def list_chunks_async( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Chunk]: """This is the async version of `Document.list_chunks`.""" if request_options is None: @@ -1105,7 +1104,7 @@ def query( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """ Query a `Document` in the `Corpus` for information. @@ -1158,7 +1157,7 @@ async def query_async( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """This is the async version of `Document.query`.""" if request_options is None: @@ -1205,7 +1204,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified document. @@ -1242,7 +1241,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.update`.""" if request_options is None: @@ -1270,7 +1269,7 @@ def batch_update_chunks( self, chunks: BatchUpdateChunksOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update multiple chunks within the same document. @@ -1367,7 +1366,7 @@ async def batch_update_chunks_async( self, chunks: BatchUpdateChunksOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_update_chunks`.""" if request_options is None: @@ -1455,7 +1454,7 @@ def delete_chunk( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, # fmt: {} + request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """ Delete a `Chunk`. @@ -1480,7 +1479,7 @@ async def delete_chunk_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, # fmt: {} + request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """This is the async version of `Document.delete_chunk`.""" if request_options is None: @@ -1499,7 +1498,7 @@ def batch_delete_chunks( self, chunks: BatchDeleteChunkOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Delete multiple `Chunk`s from a document. @@ -1532,7 +1531,7 @@ async def batch_delete_chunks_async( self, chunks: BatchDeleteChunkOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_delete_chunks`.""" if request_options is None: @@ -1638,7 +1637,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified `Chunk`. @@ -1687,7 +1686,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Chunk.update`.""" if request_options is None: From 511c6aa8cee2ab42f10b44347b8f141d8e1289ec Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 15:35:39 -0700 Subject: [PATCH 02/11] Add helper_types Change-Id: Idc3e813616413f4ce085c05b771c0127e4dfc886 --- google/generativeai/types/helper_types.py | 53 +++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 google/generativeai/types/helper_types.py diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py new file mode 100644 index 000000000..4904875e0 --- /dev/null +++ b/google/generativeai/types/helper_types.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import google.api_core.timeout +import google.api_core.retry + +import dataclasses + +from typing_extensions import TypedDict, Union + + +class RequestOptionsDict(TypedDict, total=False): + retry: google.api_core.retry.Retry + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout + + +@dataclasses.dataclass +class RequestOptions: + retry: google.api_core.retry.Retry | None + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None + + def to_dict(self): + result = {} + + retry = self.retry + if retry is not None: + result['retry'] = retry + timeout = self.timeout + if timeout is not None: + result['timeout'] = timeout + + return result + + +RequestOptionsType = Union[RequestOptions, RequestOptionsDict] + +def echo (request_options): + if isinstance(request_options, RequestOptions): + return request_options.to_dict() + elif isinstance(request_options, dict): + return request_options \ No newline at end of file From e0c781c247741cc3700b249715182783afb7c294 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 15:37:42 -0700 Subject: [PATCH 03/11] format Change-Id: I186e015de97ceece56ee5a97f6edef47ef223d18 --- google/generativeai/discuss.py | 1 - google/generativeai/models.py | 16 ++++++++-------- google/generativeai/types/helper_types.py | 9 +++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index bb7abf3a6..1a4345550 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -31,7 +31,6 @@ from google.generativeai.types import safety_types - def _make_message(content: discuss_types.MessageOptions) -> glm.Message: """Creates a `glm.Message` object from the provided content.""" if isinstance(content, glm.Message): diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 00c51d24f..92f9c27be 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -32,7 +32,7 @@ def get_model( name: model_types.AnyModelNameOptions, *, client=None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model | model_types.TunedModel: """Given a model name, fetch the `types.Model` or `types.TunedModel` object. @@ -63,7 +63,7 @@ def get_base_model( name: model_types.BaseModelNameOptions, *, client=None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model: """Get the `types.Model` for the given base model name. @@ -100,7 +100,7 @@ def get_tuned_model( name: model_types.TunedModelNameOptions, *, client=None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: """Get the `types.TunedModel` for the given tuned model name. @@ -163,7 +163,7 @@ def list_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.ModelsIterable: """Lists available models. @@ -197,7 +197,7 @@ def list_tuned_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModelsIterable: """Lists available models. @@ -245,7 +245,7 @@ def create_tuned_model( input_key: str = "text_input", output_key: str = "output", client: glm.ModelServiceClient | None = None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> operations.CreateTunedModelOperation: """Launches a tuning job to create a TunedModel. @@ -358,7 +358,7 @@ def update_tuned_model( updates: None = None, *, client: glm.ModelServiceClient | None = None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -369,7 +369,7 @@ def update_tuned_model( updates: dict[str, Any], *, client: glm.ModelServiceClient | None = None, - request_options: helper_types.RequestOptionsType| None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 4904875e0..481569dc9 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -36,18 +36,19 @@ def to_dict(self): retry = self.retry if retry is not None: - result['retry'] = retry + result["retry"] = retry timeout = self.timeout if timeout is not None: - result['timeout'] = timeout + result["timeout"] = timeout return result RequestOptionsType = Union[RequestOptions, RequestOptionsDict] -def echo (request_options): + +def echo(request_options): if isinstance(request_options, RequestOptions): return request_options.to_dict() elif isinstance(request_options, dict): - return request_options \ No newline at end of file + return request_options From c2aaad46c745f269a2cd9495a67db4cd860dd26d Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 16:10:19 -0700 Subject: [PATCH 04/11] UpdateRequestOptions Change-Id: I9f92466967fb1aa605d442cb143699da4308409b --- google/generativeai/models.py | 2 ++ google/generativeai/types/__init__.py | 1 + google/generativeai/types/helper_types.py | 34 +++++++++++------------ 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 92f9c27be..bc517d7cb 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -345,6 +345,7 @@ def create_tuned_model( top_k=top_k, tuning_task=tuning_task, ) + operation = client.create_tuned_model( dict(tuned_model_id=id, tuned_model=tuned_model), **request_options ) @@ -396,6 +397,7 @@ def update_tuned_model( "`updates` must be a `dict`.\n" f"got: {type(updates)}" ) + tuned_model = client.get_tuned_model(name=name, **request_options) updates = flatten_update_paths(updates) diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py index dc0a76761..21768bbe6 100644 --- a/google/generativeai/types/__init__.py +++ b/google/generativeai/types/__init__.py @@ -19,6 +19,7 @@ from google.generativeai.types.discuss_types import * from google.generativeai.types.file_types import * from google.generativeai.types.generation_types import * +from google.generativeai.types.helper_types import * from google.generativeai.types.model_types import * from google.generativeai.types.safety_types import * from google.generativeai.types.text_types import * diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 481569dc9..764cc3753 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -16,10 +16,13 @@ import google.api_core.timeout import google.api_core.retry +import collections import dataclasses from typing_extensions import TypedDict, Union +__all__ = ["RequestOptions", "RequestOptionsType"] + class RequestOptionsDict(TypedDict, total=False): retry: google.api_core.retry.Retry @@ -27,28 +30,25 @@ class RequestOptionsDict(TypedDict, total=False): @dataclasses.dataclass -class RequestOptions: +class RequestOptions(collections.abc.Mapping): retry: google.api_core.retry.Retry | None timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None - def to_dict(self): - result = {} + # Inherit from Mapping for **unpacking + def __getitem__(self, item): + if item == "retry": + return self.retry + elif item == "timeout": + return self.timeout + else: + raise KeyError(f'RequestOptions does not have a "{item}" key') - retry = self.retry - if retry is not None: - result["retry"] = retry - timeout = self.timeout - if timeout is not None: - result["timeout"] = timeout + def __iter__(self): + yield "retry" + yield "timeout" - return result + def __len__(self): + return 2 RequestOptionsType = Union[RequestOptions, RequestOptionsDict] - - -def echo(request_options): - if isinstance(request_options, RequestOptions): - return request_options.to_dict() - elif isinstance(request_options, dict): - return request_options From 424f922125bffb5a82dee6f7957a2491b440fff0 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 16:54:32 -0700 Subject: [PATCH 05/11] Add docs Change-Id: I209b2b2ad8d783001b1828cbcac84ca301c11bec --- google/generativeai/types/helper_types.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 764cc3753..998e381e9 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -31,6 +31,24 @@ class RequestOptionsDict(TypedDict, total=False): @dataclasses.dataclass class RequestOptions(collections.abc.Mapping): + """Request options + + >>> import google.generativeai as genai + >>> from google.generativeai.types import RequestOptions + >>> from google.api_core import retry + >>> + >>> model = genai.GenerativeModel() + >>> response = model.generate_content('Hello', + ... request_options=RequestOptions( + ... retry=retry.Retry(initial=10, multiplier=2, maximum=60, timeout=300))) + >>> response = model.generate_content('Hello', + ... request_options=RequestOptions(timeout=600))) + + Args: + retry: Refer to [retry docs](https://googleapis.dev/python/google-api-core/latest/retry.html) for details. + timeout: In seconds (or provide a [TimeToDeadlineTimeout](https://googleapis.dev/python/google-api-core/latest/timeout.html) object). + """ + retry: google.api_core.retry.Retry | None timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None From fd78b30bc4874ede25efda397ca50062d30e4992 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 17:00:40 -0700 Subject: [PATCH 06/11] work Change-Id: I00a2e2edb1e9bf3d4f51c0a868a34e044be3c6ff --- google/generativeai/answer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 87828348e..ff01e5198 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -318,6 +318,7 @@ async def generate_answer_async( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Calls the API and returns a `types.Answer` containing the answer. @@ -352,6 +353,6 @@ async def generate_answer_async( answer_style=answer_style, ) - response = await client.generate_answer(request) + response = await client.generate_answer(request, **request_options) return response From 25ea4168e8782953223cbddf5b7d0863b1c53e36 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 17:09:27 -0700 Subject: [PATCH 07/11] Fix Py3.9 Change-Id: I8cf0ccac90ba3c4548e7549fec7d0b9b58925e7e --- google/generativeai/types/helper_types.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 998e381e9..4cb42458b 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -19,14 +19,15 @@ import collections import dataclasses -from typing_extensions import TypedDict, Union +from typing import Union +from typing_extensions import TypedDict __all__ = ["RequestOptions", "RequestOptionsType"] class RequestOptionsDict(TypedDict, total=False): retry: google.api_core.retry.Retry - timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout + timeout: Union[int, float, google.api_core.timeout.TimeToDeadlineTimeout] @dataclasses.dataclass From 78b74072c9fcabd975415c91a39e97e69dd53a9d Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 17:19:21 -0700 Subject: [PATCH 08/11] use RequestOptions in tests Change-Id: I92b68bc86330ad874c3765f428a2e64ba220750f --- google/generativeai/types/helper_types.py | 11 ++++++++++- tests/test_answer.py | 3 ++- tests/test_discuss.py | 1 + tests/test_models.py | 3 ++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 4cb42458b..71040d4f7 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -30,7 +30,7 @@ class RequestOptionsDict(TypedDict, total=False): timeout: Union[int, float, google.api_core.timeout.TimeToDeadlineTimeout] -@dataclasses.dataclass +@dataclasses.dataclass(init=False) class RequestOptions(collections.abc.Mapping): """Request options @@ -53,6 +53,15 @@ class RequestOptions(collections.abc.Mapping): retry: google.api_core.retry.Retry | None timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None + def __init__( + self, + *, + retry: google.api_core.retry.Retry | None = None, + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None = None, + ): + self.retry = retry + self.timeout = timeout + # Inherit from Mapping for **unpacking def __getitem__(self, item): if item == "retry": diff --git a/tests/test_answer.py b/tests/test_answer.py index 6fa12603c..4128567f4 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -21,6 +21,7 @@ import google.ai.generativelanguage as glm from google.generativeai import answer +from google.generativeai import types as genai_types from google.generativeai import client from absl.testing import absltest from absl.testing import parameterized @@ -239,7 +240,7 @@ def test_generate_answer(self): def test_generate_answer_called_with_request_options(self): self.client.generate_answer = mock.MagicMock() request = mock.ANY - request_options = {"timeout": 120} + request_options = genai_types.RequestOptions(timeout=120) answer.generate_answer(contents=[], inline_passages=[], request_options=request_options) diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 9d628a42c..a53e2efa8 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -21,6 +21,7 @@ from google.generativeai import discuss from google.generativeai import client +from google.generativeai import types import google.generativeai as genai from google.generativeai.types import safety_types diff --git a/tests/test_models.py b/tests/test_models.py index e971ef86d..f39ed3a2c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -31,6 +31,7 @@ from google.generativeai import models from google.generativeai import client from google.generativeai.types import model_types +from google.generativeai import types as genai_types import pandas as pd @@ -470,7 +471,7 @@ def test_get_model_called_with_request_options(self): def test_get_tuned_model_called_with_request_options(self): self.client.get_tuned_model = unittest.mock.MagicMock() name = unittest.mock.ANY - request_options = {"timeout": 120} + request_options = genai_types.RequestOptions(timeout=120) try: models.get_model(name="tunedModels/", request_options=request_options) From 4d653f8a5266418912e9a4a378cb02f5b240b69b Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 26 Apr 2024 17:25:50 -0700 Subject: [PATCH 09/11] annotations Change-Id: Idbc428075729255d66d2ba8b3bcce0a1d6e8f048 --- google/generativeai/types/helper_types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py index 71040d4f7..3eba4d3f9 100644 --- a/google/generativeai/types/helper_types.py +++ b/google/generativeai/types/helper_types.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import google.api_core.timeout import google.api_core.retry From 11afaa3fd44e2a89f0dbc02ab18aed3f8a60a50b Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 8 May 2024 05:38:10 -0700 Subject: [PATCH 10/11] Update tests/test_discuss.py Co-authored-by: Mark McDonald --- tests/test_discuss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_discuss.py b/tests/test_discuss.py index a53e2efa8..9d628a42c 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -21,7 +21,6 @@ from google.generativeai import discuss from google.generativeai import client -from google.generativeai import types import google.generativeai as genai from google.generativeai.types import safety_types From 418a38c4a2ab921c0ab084157bfa54f08c96a574 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 8 May 2024 06:48:44 -0700 Subject: [PATCH 11/11] tests Change-Id: Ife30e2cc47bd4c52d2dddafdd85a51df0e42e160 --- tests/test_discuss.py | 1 - tests/test_helpers.py | 83 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) create mode 100644 tests/test_helpers.py diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 9d628a42c..e7efd5ef2 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -from typing import Any import unittest.mock diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 000000000..0c2de7f29 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib +import copy +import collections +from typing import Union + +from absl.testing import parameterized + +import google.ai.generativelanguage as glm + +from google.generativeai import client +from google.generativeai import models +from google.generativeai.types import model_types +from google.generativeai.types import helper_types + +from google.api_core import retry + + +class MockModelClient: + def __init__(self, test): + self.test = test + + def get_model( + self, + request: Union[glm.GetModelRequest, None] = None, + *, + name=None, + timeout=None, + retry=None + ) -> glm.Model: + if request is None: + request = glm.GetModelRequest(name=name) + self.test.assertIsInstance(request, glm.GetModelRequest) + self.test.observed_requests.append(request) + self.test.observed_timeout.append(timeout) + self.test.observed_retry.append(retry) + response = copy.copy(self.test.responses["get_model"]) + return response + + +class HelperTests(parameterized.TestCase): + + def setUp(self): + self.client = MockModelClient(self) + client._client_manager.clients["model"] = self.client + + self.observed_requests = [] + self.observed_retry = [] + self.observed_timeout = [] + self.responses = collections.defaultdict(list) + + @parameterized.named_parameters( + ["None", None, None, None], + ["Empty", {}, None, None], + ["Timeout", {"timeout": 7}, 7, None], + ["Retry", {"retry": retry.Retry(timeout=7)}, None, retry.Retry(timeout=7)], + [ + "RequestOptions", + helper_types.RequestOptions(timeout=7, retry=retry.Retry(multiplier=3)), + 7, + retry.Retry(multiplier=3), + ], + ) + def test_get_model(self, request_options, expected_timeout, expected_retry): + self.responses = {"get_model": glm.Model(name="models/fake-bison-001")} + + _ = models.get_model("models/fake-bison-001", request_options=request_options) + + self.assertEqual(self.observed_timeout[0], expected_timeout) + self.assertEqual(str(self.observed_retry[0]), str(expected_retry))