From f9782007c58ee11fe276c373d8d7ac6c2b0cb249 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Fri, 29 Sep 2023 13:33:59 -0700 Subject: [PATCH] feat: add Model Garden support to vertexai.preview.from_pretrained PiperOrigin-RevId: 569576160 --- .../system/aiplatform/test_language_models.py | 23 ++- tests/unit/aiplatform/test_language_models.py | 88 ++++++++++ tests/unit/aiplatform/test_vision_models.py | 62 +++++-- tests/unit/vertexai/test_model_utils.py | 155 +++++++++++++++++ .../_model_garden/_model_garden_models.py | 162 ++++++++++++++---- .../preview/_workflow/shared/model_utils.py | 147 +++++++++++++--- vertexai/preview/vision_models.py | 5 + vertexai/vision_models/_vision_models.py | 9 + 8 files changed, 581 insertions(+), 70 deletions(-) diff --git a/tests/system/aiplatform/test_language_models.py b/tests/system/aiplatform/test_language_models.py index f7ed055fd3..3245a5c2cf 100644 --- a/tests/system/aiplatform/test_language_models.py +++ b/tests/system/aiplatform/test_language_models.py @@ -24,10 +24,13 @@ from google.cloud.aiplatform.compat.types import ( job_state as gca_job_state, ) +import vertexai from tests.system.aiplatform import e2e_base from google.cloud.aiplatform.utils import gcs_utils from vertexai import language_models -from vertexai.preview import language_models as preview_language_models +from vertexai.preview import ( + language_models as preview_language_models, +) from vertexai.preview.language_models import ( ChatModel, InputOutputTextPair, @@ -87,6 +90,24 @@ def test_text_generation_streaming(self): ): assert response.text + def test_preview_text_embedding_top_level_from_pretrained(self): + aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) + + model = vertexai.preview.from_pretrained( + foundation_model_name="google/text-bison@001" + ) + + assert model.predict( + "What is the best recipe for banana bread? 
Recipe:", + max_output_tokens=128, + temperature=0.0, + top_p=1.0, + top_k=5, + stop_sequences=["# %%"], + ).text + + assert isinstance(model, preview_language_models.TextEmbeddingModel) + def test_chat_on_chat_model(self): aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 3032c19c10..fa5f358262 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -58,6 +58,7 @@ model as gca_model, ) +import vertexai from vertexai.preview import ( language_models as preview_language_models, ) @@ -2598,6 +2599,93 @@ def test_batch_prediction_for_text_embedding(self): model_parameters={}, ) + def test_text_generation_top_level_from_pretrained_preview(self): + """Tests the text generation model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = vertexai.preview.from_pretrained( + foundation_model_name="text-bison@001" + ) + + assert isinstance(model, preview_language_models.TextGenerationModel) + + mock_get_publisher_model.assert_called_with( + name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY + ) + assert mock_get_publisher_model.call_count == 1 + + assert ( + model._model_resource_name + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001" + ) + + # Test that methods on TextGenerationModel still work as expected + gca_predict_response = gca_prediction_service.PredictResponse() + gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + 
attribute="predict", + return_value=gca_predict_response, + ): + response = model.predict( + "What is the best recipe for banana bread? Recipe:", + max_output_tokens=128, + temperature=0.0, + top_p=1.0, + top_k=5, + ) + + assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] + assert ( + response.raw_prediction_response.predictions[0] + == _TEST_TEXT_GENERATION_PREDICTION + ) + assert ( + response.safety_attributes["Violent"] + == _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0] + ) + + def test_text_embedding_top_level_from_pretrained_preview(self): + """Tests the text embedding model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = vertexai.preview.from_pretrained( + foundation_model_name="textembedding-gecko@001" + ) + + assert isinstance(model, preview_language_models.TextEmbeddingModel) + + assert ( + model._endpoint_name + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001" + ) + + mock_get_publisher_model.assert_called_with( + name="publishers/google/models/textembedding-gecko@001", + retry=base._DEFAULT_RETRY, + ) + + assert mock_get_publisher_model.call_count == 1 + # TODO (b/285946649): add more test coverage before public preview release @pytest.mark.usefixtures("google_auth_mock") diff --git a/tests/unit/aiplatform/test_vision_models.py b/tests/unit/aiplatform/test_vision_models.py index 314915d806..80c366188a 100644 --- a/tests/unit/aiplatform/test_vision_models.py +++ b/tests/unit/aiplatform/test_vision_models.py @@ -39,8 +39,11 @@ from google.cloud.aiplatform.compat.types import ( publisher_model as gca_publisher_model, ) +import vertexai from vertexai import vision_models 
as ga_vision_models -from vertexai.preview import vision_models +from vertexai.preview import ( + vision_models as preview_vision_models, +) from PIL import Image as PIL_Image import pytest @@ -121,12 +124,12 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]: def generate_image_from_file( width: int = 100, height: int = 100 -) -> vision_models.Image: +) -> ga_vision_models.Image: with tempfile.TemporaryDirectory() as temp_dir: image_path = os.path.join(temp_dir, "image.png") pil_image = PIL_Image.new(mode="RGB", size=(width, height)) pil_image.save(image_path, format="PNG") - return vision_models.Image.load_from_file(image_path) + return ga_vision_models.Image.load_from_file(image_path) @pytest.mark.usefixtures("google_auth_mock") @@ -140,7 +143,7 @@ def setup_method(self): def teardown_method(self): initializer.global_pool.shutdown(wait=True) - def _get_image_generation_model(self) -> vision_models.ImageGenerationModel: + def _get_image_generation_model(self) -> preview_vision_models.ImageGenerationModel: """Gets the image generation model.""" aiplatform.init( project=_TEST_PROJECT, @@ -153,7 +156,7 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel: _IMAGE_GENERATION_PUBLISHER_MODEL_DICT ), ) as mock_get_publisher_model: - model = vision_models.ImageGenerationModel.from_pretrained( + model = preview_vision_models.ImageGenerationModel.from_pretrained( "imagegeneration@002" ) @@ -164,6 +167,34 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel: return model + def _get_preview_image_generation_model_top_level_from_pretrained( + self, + ) -> preview_vision_models.ImageGenerationModel: + """Gets the image generation model from the top-level vertexai.preview.from_pretrained method.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + 
return_value=gca_publisher_model.PublisherModel( + _IMAGE_GENERATION_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = vertexai.preview.from_pretrained( + foundation_model_name="imagegeneration@002" + ) + + mock_get_publisher_model.assert_called_with( + name="publishers/google/models/imagegeneration@002", + retry=base._DEFAULT_RETRY, + ) + + assert mock_get_publisher_model.call_count == 1 + + return model + def test_from_pretrained(self): model = self._get_image_generation_model() assert ( @@ -171,6 +202,13 @@ def test_from_pretrained(self): == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002" ) + def test_top_level_from_pretrained_preview(self): + model = self._get_preview_image_generation_model_top_level_from_pretrained() + assert ( + model._endpoint_name + == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002" + ) + def test_generate_images(self): """Tests the image generation model.""" model = self._get_image_generation_model() @@ -238,7 +276,7 @@ def test_generate_images(self): with tempfile.TemporaryDirectory() as temp_dir: image_path = os.path.join(temp_dir, "image.png") image_response[0].save(location=image_path) - image1 = vision_models.GeneratedImage.load_from_file(image_path) + image1 = preview_vision_models.GeneratedImage.load_from_file(image_path) # assert image1._pil_image.size == (width, height) assert image1.generation_parameters assert image1.generation_parameters["prompt"] == prompt1 @@ -247,7 +285,7 @@ def test_generate_images(self): mask_path = os.path.join(temp_dir, "mask.png") mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size) mask_pil_image.save(mask_path, format="PNG") - mask_image = vision_models.Image.load_from_file(mask_path) + mask_image = preview_vision_models.Image.load_from_file(mask_path) # Test generating image from base image with mock.patch.object( @@ -408,7 +446,7 @@ def 
test_upscale_image_on_provided_image(self): assert image_upscale_parameters["mode"] == "upscale" assert upscaled_image._image_bytes - assert isinstance(upscaled_image, vision_models.GeneratedImage) + assert isinstance(upscaled_image, preview_vision_models.GeneratedImage) def test_upscale_image_raises_if_not_1024x1024(self): """Tests image upscaling on generated images.""" @@ -457,7 +495,7 @@ def test_get_captions(self): image_path = os.path.join(temp_dir, "image.png") pil_image = PIL_Image.new(mode="RGB", size=(100, 100)) pil_image.save(image_path, format="PNG") - image = vision_models.Image.load_from_file(image_path) + image = preview_vision_models.Image.load_from_file(image_path) with mock.patch.object( target=prediction_service_client.PredictionServiceClient, @@ -544,7 +582,7 @@ def test_image_embedding_model_with_only_image(self): _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT ), ) as mock_get_publisher_model: - model = vision_models.MultiModalEmbeddingModel.from_pretrained( + model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained( "multimodalembedding@001" ) @@ -583,7 +621,7 @@ def test_image_embedding_model_with_image_and_text(self): _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT ), ): - model = vision_models.MultiModalEmbeddingModel.from_pretrained( + model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained( "multimodalembedding@001" ) @@ -715,7 +753,7 @@ def test_get_captions(self): image_path = os.path.join(temp_dir, "image.png") pil_image = PIL_Image.new(mode="RGB", size=(100, 100)) pil_image.save(image_path, format="PNG") - image = vision_models.Image.load_from_file(image_path) + image = preview_vision_models.Image.load_from_file(image_path) with mock.patch.object( target=prediction_service_client.PredictionServiceClient, diff --git a/tests/unit/vertexai/test_model_utils.py b/tests/unit/vertexai/test_model_utils.py index 7c0ebf548a..77aeed3283 100644 --- a/tests/unit/vertexai/test_model_utils.py +++ b/tests/unit/vertexai/test_model_utils.py 
@@ -30,6 +30,16 @@ custom_job as gca_custom_job, io as gca_io, ) +from google.cloud.aiplatform.compat.services import ( + model_garden_service_client, + model_service_client, +) +from google.cloud.aiplatform.compat.types import ( + deployed_model_ref_v1, + model as gca_model, + publisher_model as gca_publisher_model, +) +from vertexai.preview import language_models import pytest import cloudpickle @@ -58,6 +68,28 @@ # customJob constants _TEST_CUSTOM_JOB_RESOURCE_NAME = "projects/123/locations/us-central1/customJobs/456" +# Tuned model constants +_TEST_ID = "123456789" +_TEST_TUNED_MODEL_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ID}" +) +_TEST_TUNED_MODEL_ENDPOINT_NAME = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}" +) + +_TEXT_BISON_PUBLISHER_MODEL_DICT = { + "name": "publishers/google/models/text-bison", + "version_id": "001", + "open_source_category": "PROPRIETARY", + "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA, + "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/text-bison@001", + "predict_schemata": { + "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml", + "parameters_schema_uri": "gs://google-cloud-aiplatfrom/schema/predict/params/text_generation_1.0.0.yaml", + "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_generation_1.0.0.yaml", + }, +} + @pytest.fixture def mock_serialize_model(): @@ -81,6 +113,7 @@ def mock_vertex_model_invalid(): model = mock.MagicMock(aiplatform.Model) model.uri = _TEST_MODEL_GCS_URI model.container_spec.image_uri = "us-docker.xxx/sklearn-cpu.1-0:latest" + model.labels = {} yield model @@ -256,6 +289,59 @@ def mock_get_custom_job_failed(): yield mock_get_custom_job +@pytest.fixture +def get_model_with_tuned_version_label_mock(): + with mock.patch.object( + model_service_client.ModelServiceClient, 
"get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name="test-display-name", + name=_TEST_TUNED_MODEL_NAME, + labels={"google-vertex-llm-tuning-base-model-id": "text-bison-001"}, + deployed_models=[ + deployed_model_ref_v1.DeployedModelRef( + endpoint=_TEST_TUNED_MODEL_ENDPOINT_NAME, + deployed_model_id=_TEST_TUNED_MODEL_NAME, + ) + ], + ) + yield get_model_mock + + +@pytest.fixture +def get_model_with_invalid_tuned_version_labels(): + with mock.patch.object( + model_service_client.ModelServiceClient, "get_model" + ) as get_model_mock: + get_model_mock.return_value = gca_model.Model( + display_name="test-display-name", + name=_TEST_TUNED_MODEL_NAME, + labels={ + "google-vertex-llm-tuning-base-model-id": "invalidlabel", + "another": "label", + }, + deployed_models=[ + deployed_model_ref_v1.DeployedModelRef( + endpoint=_TEST_TUNED_MODEL_ENDPOINT_NAME, + deployed_model_id=_TEST_TUNED_MODEL_NAME, + ) + ], + ) + yield get_model_mock + + +@pytest.fixture +def mock_get_publisher_model(): + with mock.patch.object( + model_garden_service_client.ModelGardenServiceClient, + "get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _TEXT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + yield mock_get_publisher_model + + @pytest.mark.usefixtures("google_auth_mock") class TestModelUtils: def setup_method(self): @@ -491,3 +577,72 @@ def test_custom_job_from_pretrained_fails_on_errored_job(self): custom_job_name=_TEST_CUSTOM_JOB_RESOURCE_NAME ) assert "did not complete" in err_msg + + @pytest.mark.usefixtures( + "mock_get_publisher_model", + ) + def test_from_pretrained_with_preview_foundation_model(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + foundation_model = vertexai.preview.from_pretrained( + foundation_model_name="text-bison@001" + ) + assert isinstance(foundation_model, language_models._PreviewTextGenerationModel) + + 
@pytest.mark.usefixtures( + "get_model_with_tuned_version_label_mock", + ) + def test_from_pretrained_with_preview_tuned_mg_model( + self, mock_get_publisher_model + ): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + tuned_model = vertexai.preview.from_pretrained(model_name=_TEST_ID) + assert mock_get_publisher_model.call_count == 1 + assert isinstance(tuned_model, language_models._PreviewTextGenerationModel) + assert tuned_model._endpoint_name == _TEST_TUNED_MODEL_ENDPOINT_NAME + assert tuned_model._model_id == "publishers/google/models/text-bison@001" + + @pytest.mark.usefixtures( + "mock_get_publisher_model", + "get_model_with_invalid_tuned_version_labels", + ) + def test_from_pretrained_raises_on_invalid_model_registry_model(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + with pytest.raises(ValueError): + vertexai.preview.from_pretrained(model_name=_TEST_ID) + + def test_from_pretrained_raises_with_more_than_one_arg(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + with pytest.raises(ValueError): + vertexai.preview.from_pretrained( + model_name=_TEST_ID, foundation_model_name="text-bison@001" + ) + + def test_from_pretrained_raises_with_no_args_passed(self): + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_BUCKET, + ) + + with pytest.raises(ValueError): + vertexai.preview.from_pretrained() diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index a32a5d2b7b..277a91169b 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -24,6 +24,9 @@ from google.cloud.aiplatform import models as aiplatform_models from google.cloud.aiplatform import _publisher_models +# this is needed for class registration to _SUBCLASSES 
+import vertexai  # pylint:disable=unused-import
+
 from google.cloud.aiplatform.compat.types import (
     publisher_model as gca_publisher_model,
 )
@@ -56,6 +59,36 @@
 T = TypeVar("T", bound="_ModelGardenModel")
 
 
+# When this module is initialized, _SUBCLASSES contains a mapping of SDK class to the Model Garden instance for that class.
+# The key is the SDK class since multiple classes can share a schema URI (i.e. _PreviewTextGenerationModel and TextGenerationModel)
+# For example: {TextGenerationModel: "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"}
+_SUBCLASSES = {}
+
+
+def _get_model_class_from_schema_uri(
+    schema_uri: str,
+) -> Type["_ModelGardenModel"]:
+    """Gets the _ModelGardenModel class for the provided PublisherModel schema uri.
+
+    Args:
+        schema_uri (str): The schema_uri for the provided PublisherModel, for example:
+            "gs://google-cloud-aiplatform/schema/predict/instance/text_generation_1.0.0.yaml"
+
+    Returns:
+        The _ModelGardenModel class associated with the provided schema uri.
+
+    Raises:
+        ValueError
+            If the provided PublisherModel schema_uri isn't supported by the SDK in Preview.
+    """
+
+    for sdk_class in _SUBCLASSES:
+        class_schema_uri = _SUBCLASSES[sdk_class]
+        if class_schema_uri == schema_uri and "preview" in sdk_class.__module__:
+            return sdk_class
+
+    raise ValueError("This model is not supported in Preview by the Vertex SDK.")
+
 
 @dataclasses.dataclass
 class _ModelInfo:
@@ -67,7 +100,11 @@ class _ModelInfo:
 
 
 def _get_model_info(
-    model_id: str, schema_to_class_map: Dict[str, "_ModelGardenModel"]
+    model_id: str,
+    schema_to_class_map: Optional[Dict[str, "_ModelGardenModel"]] = None,
+    interface_class: Optional[Type["_ModelGardenModel"]] = None,
+    publisher_model_res: Optional[_publisher_models._PublisherModel] = None,
+    tuned_vertex_model: Optional[aiplatform.Model] = None,
 ) -> _ModelInfo:
     """Gets the model information by model ID.
@@ -92,11 +129,12 @@ def _get_model_info( if "/" not in model_id: model_id = "publishers/google/models/" + model_id - publisher_model_res = ( - _publisher_models._PublisherModel( # pylint: disable=protected-access - resource_name=model_id - )._gca_resource - ) + if not publisher_model_res: + publisher_model_res = ( + _publisher_models._PublisherModel( # pylint: disable=protected-access + resource_name=model_id + )._gca_resource + ) if not publisher_model_res.name.startswith("publishers/google/models/"): raise ValueError( @@ -113,10 +151,18 @@ def _get_model_info( f"The model does not have an associated Publisher Model. {publisher_model_res.name}" ) - endpoint_name = publisher_model_template.format( - project=aiplatform_initializer.global_config.project, - location=aiplatform_initializer.global_config.location, - ) + if not tuned_vertex_model: + endpoint_name = publisher_model_template.format( + project=aiplatform_initializer.global_config.project, + location=aiplatform_initializer.global_config.location, + ) + else: + tuned_model_deployments = tuned_vertex_model.gca_resource.deployed_models + if len(tuned_model_deployments) == 0: + # Deploying the model + endpoint_name = tuned_vertex_model.deploy().resource_name + else: + endpoint_name = tuned_model_deployments[0].endpoint if short_model_id in _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP: tuning_pipeline_uri = _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP[short_model_id] @@ -125,15 +171,16 @@ def _get_model_info( tuning_pipeline_uri = None tuning_model_id = None - interface_class = schema_to_class_map.get( - publisher_model_res.predict_schemata.instance_schema_uri - ) - - if not interface_class: - raise ValueError( - f"Unknown model {publisher_model_res.name}; {schema_to_class_map}" + if schema_to_class_map: + interface_class = schema_to_class_map.get( + publisher_model_res.predict_schemata.instance_schema_uri ) + if not interface_class: + raise ValueError( + f"Unknown model {publisher_model_res.name}; {schema_to_class_map}" + 
)
+
     return _ModelInfo(
         endpoint_name=endpoint_name,
         interface_class=interface_class,
@@ -143,6 +190,62 @@
     )
 
 
+def _from_pretrained(
+    *,
+    interface_class: Optional[Type[T]] = None,
+    model_name: Optional[str] = None,
+    publisher_model: Optional[_publisher_models._PublisherModel] = None,
+    tuned_vertex_model: Optional[aiplatform.Model] = None,
+) -> T:
+    """Loads a _ModelGardenModel.
+
+    Args:
+        model_name: Name of the model.
+
+    Returns:
+        An instance of a class derived from `_ModelGardenModel`.
+
+    Raises:
+        ValueError: If model_name is unknown.
+        ValueError: If model does not support this class.
+    """
+    if interface_class:
+        if not interface_class._INSTANCE_SCHEMA_URI:
+            raise ValueError(
+                f"Class {interface_class} is not a correct model interface class since it does not have an instance schema URI."
+            )
+
+        model_info = _get_model_info(
+            model_id=model_name,
+            schema_to_class_map={interface_class._INSTANCE_SCHEMA_URI: interface_class},
+        )
+
+    else:
+        schema_uri = publisher_model._gca_resource.predict_schemata.instance_schema_uri
+        interface_class = _get_model_class_from_schema_uri(schema_uri)
+
+        model_info = _get_model_info(
+            model_id=model_name,
+            interface_class=interface_class,
+            publisher_model_res=publisher_model._gca_resource,
+            tuned_vertex_model=tuned_vertex_model,
+        )
+
+    if not issubclass(model_info.interface_class, interface_class):
+        raise ValueError(
+            f"{model_name} is of type {model_info.interface_class.__name__} not of type {interface_class.__name__}"
+        )
+
+    interface_class._validate_launch_stage(
+        interface_class, model_info.publisher_model_resource
+    )
+
+    return model_info.interface_class(
+        model_id=model_name,
+        endpoint_name=model_info.endpoint_name,
+    )
+
+
 class _ModelGardenModel:
     """Base class for shared methods and properties across Model Garden models."""
 
@@ -171,6 +274,10 @@ def _validate_launch_stage(
     # Subclasses override this attribute to specify their instance schema
     _INSTANCE_SCHEMA_URI: Optional[str]
= None + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + _SUBCLASSES[cls] = cls._INSTANCE_SCHEMA_URI + def __init__(self, model_id: str, endpoint_name: Optional[str] = None): """Creates a _ModelGardenModel. @@ -206,23 +313,4 @@ def from_pretrained(cls: Type[T], model_name: str) -> T: ValueError: If model does not support this class. """ - if not cls._INSTANCE_SCHEMA_URI: - raise ValueError( - f"Class {cls} is not a correct model interface class since it does not have an instance schema URI." - ) - - model_info = _get_model_info( - model_id=model_name, schema_to_class_map={cls._INSTANCE_SCHEMA_URI: cls} - ) - - if not issubclass(model_info.interface_class, cls): - raise ValueError( - f"{model_name} is of type {model_info.interface_class.__name__} not of type {cls.__name__}" - ) - - cls._validate_launch_stage(cls, model_info.publisher_model_resource) - - return model_info.interface_class( - model_id=model_name, - endpoint_name=model_info.endpoint_name, - ) + return _from_pretrained(interface_class=cls, model_name=model_name) diff --git a/vertexai/preview/_workflow/shared/model_utils.py b/vertexai/preview/_workflow/shared/model_utils.py index 858c43486d..663a4740b3 100644 --- a/vertexai/preview/_workflow/shared/model_utils.py +++ b/vertexai/preview/_workflow/shared/model_utils.py @@ -22,7 +22,8 @@ """ import os -from typing import Any, Optional, Union +import re +from typing import Any, Dict, Optional, Union from google.cloud import aiplatform from google.cloud.aiplatform import base @@ -34,6 +35,16 @@ any_serializer, serializers_base, ) + +# These need to be imported to be included in _ModelGardenModel.__init_subclass__ +from vertexai.language_models import ( + _language_models, +) # pylint:disable=unused-import +from vertexai.vision_models import ( + _vision_models, +) # pylint:disable=unused-import +from vertexai._model_garden import _model_garden_models +from google.cloud.aiplatform import _publisher_models from 
vertexai.preview._workflow.executor import training
 from google.cloud.aiplatform.compat.types import job_state as gca_job_state
 
@@ -236,6 +247,61 @@ def _register_pytorch_model(
     return vertex_model
 
 
+def _get_publisher_model_resource(
+    short_model_name: str,
+) -> _publisher_models._PublisherModel:
+    """Gets the PublisherModel resource from the short model name.
+
+    Args:
+        short_model_name (str):
+            Required. The short name for the model, for example 'text-bison@001'
+
+    Returns:
+        A _PublisherModel instance pointing to the PublisherModel resource for
+        this model.
+
+    Raises:
+        ValueError:
+            If no PublisherModel resource was found for the given short_model_name.
+    """
+
+    if "/" not in short_model_name:
+        short_model_name = "publishers/google/models/" + short_model_name
+
+    try:
+        publisher_model_resource = _publisher_models._PublisherModel(
+            resource_name=short_model_name
+        )
+        return publisher_model_resource
+    except Exception as e:
+        raise ValueError("Please provide a valid Model Garden model resource.") from e
+
+
+def _check_from_pretrained_passed_exactly_one_arg(fn_args: Dict[str, Any]) -> None:
+    """Checks exactly one argument was passed to from_pretrained.
+
+    This supports an expanding number of arguments added to from_pretrained.
+
+    Args:
+        fn_args (Dict[str, Any]):
+            Required. A dictionary of the arguments passed to from_pretrained.
+
+    Raises:
+        ValueError:
+            If more than one arg or no args were passed to from_pretrained.
+    """
+
+    passed_args = 0
+
+    for argval in fn_args.values():
+        if argval is not None:
+            passed_args += 1
+    if passed_args != 1:
+        raise ValueError(
+            f"Exactly one of {list(fn_args.keys())} must be provided to from_pretrained."
+ ) + + def register( model: Union[ "sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module" # noqa: F821 @@ -299,6 +365,7 @@ def from_pretrained( *, model_name: Optional[str] = None, custom_job_name: Optional[str] = None, + foundation_model_name: Optional[str] = None, ) -> Union["sklearn.base.BaseEstimator", "tf.Module", "torch.nn.Module"]: # noqa: F821 """Pulls a model from Model Registry or from a CustomJob ID for retraining. @@ -309,13 +376,16 @@ def from_pretrained( model_name (str): Optional. The resource ID or fully qualified resource name of a registered model. Format: "12345678910" or - "projects/123/locations/us-central1/models/12345678910@1". One of `model_name` or - `custom_job_name` is required. + "projects/123/locations/us-central1/models/12345678910@1". One of `model_name`, + `custom_job_name`, or `foundation_model_name` is required. custom_job_name (str): Optional. The resource ID or fully qualified resource name of a CustomJob created with Vertex SDK remote training. If the job has completed successfully, this will load - the trained model created in the CustomJob. One of `model_name` or - `custom_job_name` is required. + the trained model created in the CustomJob. One of `model_name`, `custom_job_name`, or + `foundation_model_name` is required. + foundation_model_name (str): + Optional. The name of the foundation model to load. For example: "text-bison@001". One of + `model_name`,`custom_job_name`, or `foundation_model_name` is required. Returns: model: local model for uptraining. 
@@ -326,8 +396,7 @@ def from_pretrained( If custom job was not created with Vertex SDK remote training If both or neither model_name or custom_job_name are provided """ - if not model_name and not custom_job_name or (model_name and custom_job_name): - raise ValueError("Exactly one of `model` or `custom_job` should be provided.") + _check_from_pretrained_passed_exactly_one_arg(locals()) project = vertexai.preview.global_config.project location = vertexai.preview.global_config.location @@ -338,24 +407,56 @@ def from_pretrained( vertex_model = aiplatform.Model( model_name, project=project, location=location, credentials=credentials ) - if vertex_model.labels.get("registered_by_vertex_ai") != "true": - raise ValueError( - f"The model {model_name} is not registered through `vertex_ai.register`." + if vertex_model.labels.get("registered_by_vertex_ai") == "true": + + artifact_uri = vertex_model.uri + model_file = _get_model_file_from_image_uri( + vertex_model.container_spec.image_uri ) - artifact_uri = vertex_model.uri - model_file = _get_model_file_from_image_uri( - vertex_model.container_spec.image_uri - ) + serializer = any_serializer.AnySerializer() + model = serializer.deserialize(os.path.join(artifact_uri, model_file)) - serializer = any_serializer.AnySerializer() - model = serializer.deserialize(os.path.join(artifact_uri, model_file)) + rewrapper = serializer.deserialize( + os.path.join(artifact_uri, _REWRAPPER_NAME) + ) - rewrapper = serializer.deserialize(os.path.join(artifact_uri, _REWRAPPER_NAME)) + # Rewrap model (in-place) for following remote training. + rewrapper(model) + return model - # Rewrap model (in-place) for following remote training. - rewrapper(model) - return model + elif not vertex_model.labels: + raise ValueError( + f"The model {model_name} was not registered through `vertexai.preview.register` or created from Model Garden." 
+ ) + else: + # Get the labels and check if it's a tuned model from a PublisherModel resource + for label_key in vertex_model.labels: + publisher_model_label = vertex_model.labels.get(label_key) + publisher_model_label_format_match = r"(^[a-z]+-[a-z]+-[0-9]{3}$)" + + if re.match(publisher_model_label_format_match, publisher_model_label): + # This try/except ensures this method will iterate over all models in a label even + # if one fails on PublisherModel resource creation + short_model_id = ( + _language_models._get_model_id_from_tuning_model_id( + publisher_model_label + ) + ) + + try: + publisher_model = _get_publisher_model_resource(short_model_id) + return _model_garden_models._from_pretrained( + model_name=short_model_id, + publisher_model=publisher_model, + tuned_vertex_model=vertex_model, + ) + + except ValueError: + continue + raise ValueError( + f"The model {model_name} was not created from a Model Garden model." + ) if custom_job_name: job = aiplatform.CustomJob.get( @@ -380,3 +481,9 @@ def from_pretrained( raise ValueError( "The provided job did not complete successfully. Please provide a pending or successful customJob ID." 
) + + if foundation_model_name: + publisher_model = _get_publisher_model_resource(foundation_model_name) + return _model_garden_models._from_pretrained( + model_name=foundation_model_name, publisher_model=publisher_model + ) diff --git a/vertexai/preview/vision_models.py b/vertexai/preview/vision_models.py index 67290e6736..07c8fe03c2 100644 --- a/vertexai/preview/vision_models.py +++ b/vertexai/preview/vision_models.py @@ -15,22 +15,27 @@ """Classes for working with vision models.""" from vertexai.vision_models._vision_models import ( + _PreviewImageTextModel, Image, ImageGenerationModel, ImageGenerationResponse, ImageCaptioningModel, ImageQnAModel, + ImageTextModel, GeneratedImage, MultiModalEmbeddingModel, MultiModalEmbeddingResponse, ) +ImageTextModel = _PreviewImageTextModel + __all__ = [ "Image", "ImageGenerationModel", "ImageGenerationResponse", "ImageCaptioningModel", "ImageQnAModel", + "ImageTextModel", "GeneratedImage", "MultiModalEmbeddingModel", "MultiModalEmbeddingResponse", diff --git a/vertexai/vision_models/_vision_models.py b/vertexai/vision_models/_vision_models.py index 5cd38ee00f..4cd838d4f2 100644 --- a/vertexai/vision_models/_vision_models.py +++ b/vertexai/vision_models/_vision_models.py @@ -709,6 +709,8 @@ class ImageTextModel(ImageCaptioningModel, ImageQnAModel): ) """ + __module__ = "vertexai.vision_models" + # NOTE: Using this ImageTextModel class is recommended over using ImageQnAModel or ImageCaptioningModel, # since SDK Model Garden classes should follow the design pattern of exactly 1 SDK class to 1 Model Garden schema URI @@ -716,3 +718,10 @@ class ImageTextModel(ImageCaptioningModel, ImageQnAModel): _LAUNCH_STAGE = ( _model_garden_models._SDK_GA_LAUNCH_STAGE # pylint: disable=protected-access ) + + +class _PreviewImageTextModel(ImageTextModel): + + __module__ = "vertexai.preview.vision_models" + + _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE