diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 5056125dba..6804728b1e 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -53,7 +53,10 @@ model as gca_model, ) -from vertexai.preview import language_models +from vertexai.preview import ( + language_models as preview_language_models, +) +from vertexai import language_models from google.cloud.aiplatform_v1 import Execution as GapicExecution from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec, @@ -456,7 +459,7 @@ def get_endpoint_mock(): @pytest.fixture def mock_get_tuned_model(get_endpoint_mock): with mock.patch.object( - language_models.TextGenerationModel, "get_tuned_model" + preview_language_models.TextGenerationModel, "get_tuned_model" ) as mock_text_generation_model: mock_text_generation_model._model_id = ( test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME @@ -519,6 +522,50 @@ def teardown_method(self): initializer.global_pool.shutdown(wait=True) def test_text_generation(self): + """Tests the text generation model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = preview_language_models.TextGenerationModel.from_pretrained( + "text-bison@001" + ) + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY + ) + + assert ( + model._model_resource_name + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001" + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + response = model.predict( + "What is the best recipe for banana bread? Recipe:", + max_output_tokens=128, + temperature=0, + top_p=1, + top_k=5, + ) + + assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + + def test_text_generation_ga(self): """Tests the text generation model.""" aiplatform.init( project=_TEST_PROJECT, @@ -596,7 +643,7 @@ def test_tune_model( _TEXT_BISON_PUBLISHER_MODEL_DICT ), ): - model = language_models.TextGenerationModel.from_pretrained( + model = preview_language_models.TextGenerationModel.from_pretrained( "text-bison@001" ) @@ -631,7 +678,7 @@ def test_get_tuned_model( _TEXT_BISON_PUBLISHER_MODEL_DICT ), ): - tuned_model = language_models.TextGenerationModel.get_tuned_model( + tuned_model = preview_language_models.TextGenerationModel.get_tuned_model( test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME ) @@ -651,7 +698,7 @@ def get_tuned_model_raises_if_not_called_with_mg_model(self): ) with pytest.raises(ValueError): - language_models.TextGenerationModel.get_tuned_model( + preview_language_models.TextGenerationModel.get_tuned_model( test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME ) @@ -668,7 +715,7 @@ def test_chat(self): _CHAT_BISON_PUBLISHER_MODEL_DICT ), ) as mock_get_publisher_model: - model = language_models.ChatModel.from_pretrained("chat-bison@001") + model = preview_language_models.ChatModel.from_pretrained("chat-bison@001") mock_get_publisher_model.assert_called_once_with( name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY @@ -681,11 +728,11 @@ def test_chat(self): My favorite movies are Lord of the Rings and Hobbit. """, examples=[ - language_models.InputOutputTextPair( + preview_language_models.InputOutputTextPair( input_text="Who do you work for?", output_text="I work for Ned.", ), - language_models.InputOutputTextPair( + preview_language_models.InputOutputTextPair( input_text="What do I like?", output_text="Ned likes watching movies.", ), @@ -786,7 +833,7 @@ def test_code_chat(self): _CODECHAT_BISON_PUBLISHER_MODEL_DICT ), ) as mock_get_publisher_model: - model = language_models.CodeChatModel.from_pretrained( + model = preview_language_models.CodeChatModel.from_pretrained( "google/codechat-bison@001" ) @@ -882,7 +929,7 @@ def test_code_generation(self): _CODE_GENERATION_BISON_PUBLISHER_MODEL_DICT ), ) as mock_get_publisher_model: - model = language_models.CodeGenerationModel.from_pretrained( + model = preview_language_models.CodeGenerationModel.from_pretrained( "google/code-bison@001" ) @@ -909,9 +956,11 @@ def test_code_generation(self): # Validating the parameters predict_temperature = 0.1 predict_max_output_tokens = 100 - default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE + default_temperature = ( + preview_language_models.CodeGenerationModel._DEFAULT_TEMPERATURE + ) default_max_output_tokens = ( - language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS + preview_language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS ) with mock.patch.object( @@ -948,7 +997,7 @@ def test_code_completion(self): _CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT ), ) as mock_get_publisher_model: - model = language_models.CodeGenerationModel.from_pretrained( + model = preview_language_models.CodeGenerationModel.from_pretrained( "google/code-gecko@001" ) @@ -975,9 +1024,11 @@ def test_code_completion(self): # Validating the parameters predict_temperature = 0.1 predict_max_output_tokens = 100 - default_temperature = language_models.CodeGenerationModel._DEFAULT_TEMPERATURE + default_temperature = ( + preview_language_models.CodeGenerationModel._DEFAULT_TEMPERATURE + ) default_max_output_tokens = ( - language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS + preview_language_models.CodeGenerationModel._DEFAULT_MAX_OUTPUT_TOKENS ) with mock.patch.object( @@ -1002,6 +1053,43 @@ def test_code_completion(self): assert prediction_parameters["maxOutputTokens"] == default_max_output_tokens def test_text_embedding(self): + """Tests the text embedding model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = preview_language_models.TextEmbeddingModel.from_pretrained( + "textembedding-gecko@001" + ) + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/textembedding-gecko@001", + retry=base._DEFAULT_RETRY, + ) + + gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append(_TEST_TEXT_EMBEDDING_PREDICTION) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response, + ): + embeddings = model.get_embeddings(["What is life?"]) + assert embeddings + for embedding in embeddings: + vector = embedding.values + assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH + assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"] + + def test_text_embedding_ga(self): """Tests the text embedding model.""" aiplatform.init( project=_TEST_PROJECT, diff --git a/tests/unit/aiplatform/test_model_garden_models.py b/tests/unit/aiplatform/test_model_garden_models.py index 014739022f..3d3288736b 100644 --- a/tests/unit/aiplatform/test_model_garden_models.py +++ b/tests/unit/aiplatform/test_model_garden_models.py @@ -18,7 +18,6 @@ import pytest from importlib import reload from unittest import mock -from typing import Dict, Type from google.cloud import aiplatform from google.cloud.aiplatform import base @@ -53,14 +52,7 @@ class TestModelGardenModels: """Unit tests for the _ModelGardenModel base class.""" class FakeModelGardenModel(_model_garden_models._ModelGardenModel): - @staticmethod - def _get_public_preview_class_map() -> Dict[ - str, Type[_model_garden_models._ModelGardenModel] - ]: - test_map = { - "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml": TestModelGardenModels.FakeModelGardenModel - } - return test_map + _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml" def setup_method(self): reload(initializer) diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index cb56810004..6f39d9f1ee 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -107,7 +107,9 @@ def _get_model_info( ) if not interface_class: - raise ValueError(f"Unknown model {publisher_model_res.name}") + raise ValueError( + f"Unknown model {publisher_model_res.name}; {schema_to_class_map}" + ) return _ModelInfo( endpoint_name=endpoint_name, @@ -120,18 +122,8 @@ def _get_model_info( class _ModelGardenModel: """Base class for shared methods and properties across Model Garden models.""" - @staticmethod - @abc.abstractmethod - def _get_public_preview_class_map() -> Dict[str, Type["_ModelGardenModel"]]: - """Returns a Dict mapping schema URI to model class. - - Subclasses should implement this method. Example mapping: - - { - "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml": TextGenerationModel - } - """ - pass + # Subclasses override this attribute to specify their instance schema + _INSTANCE_SCHEMA_URI: Optional[str] = None def __init__(self, model_id: str, endpoint_name: Optional[str] = None): """Creates a _ModelGardenModel. @@ -168,8 +160,13 @@ def from_pretrained(cls, model_name: str) -> "_ModelGardenModel": ValueError: If model does not support this class. """ + if not cls._INSTANCE_SCHEMA_URI: + raise ValueError( + f"Class {cls} is not a correct model interface class since it does not have an instance schema URI." + ) + model_info = _get_model_info( - model_id=model_name, schema_to_class_map=cls._get_public_preview_class_map() + model_id=model_name, schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls} ) if not issubclass(model_info.interface_class, cls): diff --git a/vertexai/language_models/__init__.py b/vertexai/language_models/__init__.py index 4d14759da2..ecab1cf7f1 100644 --- a/vertexai/language_models/__init__.py +++ b/vertexai/language_models/__init__.py @@ -12,5 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +"""Classes for working with language models.""" -from vertexai.language_models import _language_models +from vertexai.language_models._language_models import ( + InputOutputTextPair, + TextEmbedding, + TextEmbeddingModel, + TextGenerationModel, + TextGenerationResponse, +) + +__all__ = [ + "InputOutputTextPair", + "TextEmbedding", + "TextEmbeddingModel", + "TextGenerationModel", + "TextGenerationResponse", +] diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 2f3f8ed4a4..26fdc21d3b 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -15,7 +15,7 @@ """Classes for working with language models.""" import dataclasses -from typing import Any, Dict, List, Optional, Sequence, Type, Union +from typing import Any, List, Optional, Sequence, Union from google.cloud import aiplatform from google.cloud.aiplatform import base @@ -36,20 +36,6 @@ # Endpoint label/metadata key to preserve the base model ID information _TUNING_BASE_MODEL_ID_LABEL_KEY = "google-vertex-llm-tuning-base-model-id" -_LLM_TEXT_GENERATION_INSTANCE_SCHEMA_URI = ( - "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml" -) -_LLM_CHAT_GENERATION_INSTANCE_SCHEMA_URI = ( - "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml" -) -_LLM_TEXT_EMBEDDING_INSTANCE_SCHEMA_URI = ( - "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml" -) -_LLM_CODE_CHAT_GENERATION_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml" -_LLM_CODE_GENERATION_INSTANCE_SCHEMA_URI = ( - "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml" -) - def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str: """Gets the base model ID for the model ID labels used the tuned models. @@ -89,18 +75,6 @@ def __init__(self, model_id: str, endpoint_name: Optional[str] = None): endpoint_name=endpoint_name, ) - @staticmethod - def _get_public_preview_class_map() -> Dict[str, Type["_LanguageModel"]]: - - interface_class_map = { - _LLM_TEXT_GENERATION_INSTANCE_SCHEMA_URI: _PreviewTextGenerationModel, - _LLM_CHAT_GENERATION_INSTANCE_SCHEMA_URI: ChatModel, - _LLM_TEXT_EMBEDDING_INSTANCE_SCHEMA_URI: TextEmbeddingModel, - _LLM_CODE_CHAT_GENERATION_INSTANCE_SCHEMA_URI: CodeChatModel, - _LLM_CODE_GENERATION_INSTANCE_SCHEMA_URI: CodeGenerationModel, - } - return interface_class_map - @property def _model_resource_name(self) -> str: """Full resource name of the model.""" @@ -122,12 +96,12 @@ def list_tuned_model_names(self) -> Sequence[str]: """ model_info = _model_garden_models._get_model_info( model_id=self._model_id, - schema_to_class_map=self._get_public_preview_class_map(), + schema_to_class_map={self._INSTANCE_SCHEMA_URI: type(self)}, ) return _list_tuned_model_names(model_id=model_info.tuning_model_id) - @staticmethod - def get_tuned_model(tuned_model_name: str) -> "_LanguageModel": + @classmethod + def get_tuned_model(cls, tuned_model_name: str) -> "_LanguageModel": """Loads the specified tuned language model.""" tuned_vertex_model = aiplatform.Model(tuned_model_name) @@ -150,7 +124,7 @@ def get_tuned_model(tuned_model_name: str) -> "_LanguageModel": base_model_id = _get_model_id_from_tuning_model_id(tuning_model_id) model_info = _model_garden_models._get_model_info( model_id=base_model_id, - schema_to_class_map=_LanguageModel._get_public_preview_class_map(), + schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls}, ) model = model_info.interface_class( model_id=base_model_id, @@ -198,7 +172,7 @@ def tune_model( ) model_info = _model_garden_models._get_model_info( model_id=self._model_id, - schema_to_class_map=self._get_public_preview_class_map(), + schema_to_class_map={self._INSTANCE_SCHEMA_URI: type(self)}, ) if not model_info.tuning_pipeline_uri: raise RuntimeError(f"The {self._model_id} model does not support tuning") @@ -241,6 +215,8 @@ class TextGenerationModel(_LanguageModel): model.predict("What is life?") """ + _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml" + _DEFAULT_TEMPERATURE = 0.0 _DEFAULT_MAX_OUTPUT_TOKENS = 128 _DEFAULT_TOP_P = 0.95 @@ -451,6 +427,10 @@ class TextEmbeddingModel(_LanguageModel): print(len(vector)) """ + _INSTANCE_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml" + ) + def get_embeddings(self, texts: List[str]) -> List["TextEmbedding"]: instances = [{"content": str(text)} for text in texts] @@ -551,7 +531,7 @@ class ChatModel(_ChatModelBase): chat.send_message("Do you know any cool events this weekend?") """ - pass + _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml" class CodeChatModel(_ChatModelBase): @@ -568,6 +548,8 @@ class CodeChatModel(_ChatModelBase): code_chat.send_message("Please help write a function to calculate the min of two numbers") """ + _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/codechat_generation_1.0.0.yaml" + _DEFAULT_MAX_OUTPUT_TOKENS = 128 _DEFAULT_TEMPERATURE = 0.5 @@ -793,6 +775,8 @@ class CodeGenerationModel(_LanguageModel): )) """ + _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/code_generation_1.0.0.yaml" + _DEFAULT_TEMPERATURE = 0.0 _DEFAULT_MAX_OUTPUT_TOKENS = 128