Skip to content

Commit

Permalink
feat: add Model Garden support to vertexai.preview.from_pretrained
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569576160
  • Loading branch information
sararob authored and copybara-github committed Sep 29, 2023
1 parent 1aab6fd commit f978200
Show file tree
Hide file tree
Showing 8 changed files with 581 additions and 70 deletions.
23 changes: 22 additions & 1 deletion tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
from google.cloud.aiplatform.compat.types import (
job_state as gca_job_state,
)
import vertexai
from tests.system.aiplatform import e2e_base
from google.cloud.aiplatform.utils import gcs_utils
from vertexai import language_models
from vertexai.preview import language_models as preview_language_models
from vertexai.preview import (
language_models as preview_language_models,
)
from vertexai.preview.language_models import (
ChatModel,
InputOutputTextPair,
Expand Down Expand Up @@ -87,6 +90,24 @@ def test_text_generation_streaming(self):
):
assert response.text

def test_preview_text_embedding_top_level_from_pretrained(self):
    """Checks vertexai.preview.from_pretrained against a live text model.

    NOTE(review): despite the historical test name, "google/text-bison@001"
    is a text *generation* foundation model and the body calls predict()
    with generation parameters, so from_pretrained must return a
    TextGenerationModel — TextEmbeddingModel exposes get_embeddings(), not
    predict(). The original isinstance assertion was inconsistent with the
    rest of the test.
    """
    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    model = vertexai.preview.from_pretrained(
        foundation_model_name="google/text-bison@001"
    )

    # Verify the resolved type first so a wrong model class fails fast,
    # before making a billable prediction call against the live service.
    assert isinstance(model, preview_language_models.TextGenerationModel)

    assert model.predict(
        "What is the best recipe for banana bread? Recipe:",
        max_output_tokens=128,
        temperature=0.0,
        top_p=1.0,
        top_k=5,
        stop_sequences=["# %%"],
    ).text

def test_chat_on_chat_model(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

Expand Down
88 changes: 88 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
model as gca_model,
)

import vertexai
from vertexai.preview import (
language_models as preview_language_models,
)
Expand Down Expand Up @@ -2598,6 +2599,93 @@ def test_batch_prediction_for_text_embedding(self):
model_parameters={},
)

def test_text_generation_top_level_from_pretrained_preview(self):
    """Verifies vertexai.preview.from_pretrained resolves a TextGenerationModel."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_BISON_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        model_garden_service_client.ModelGardenServiceClient,
        "get_publisher_model",
        return_value=publisher_model,
    ) as mock_get_publisher_model:
        model = vertexai.preview.from_pretrained(
            foundation_model_name="text-bison@001"
        )

    assert isinstance(model, preview_language_models.TextGenerationModel)

    mock_get_publisher_model.assert_called_with(
        name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY
    )
    assert mock_get_publisher_model.call_count == 1

    expected_resource_name = (
        f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
        "/publishers/google/models/text-bison@001"
    )
    assert model._model_resource_name == expected_resource_name

    # The returned model must still behave like a directly constructed
    # TextGenerationModel: predict() goes through the prediction service.
    fake_response = gca_prediction_service.PredictResponse()
    fake_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)

    with mock.patch.object(
        prediction_service_client.PredictionServiceClient,
        "predict",
        return_value=fake_response,
    ):
        response = model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0.0,
            top_p=1.0,
            top_k=5,
        )

    assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
    assert (
        response.raw_prediction_response.predictions[0]
        == _TEST_TEXT_GENERATION_PREDICTION
    )
    expected_score = _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
    assert response.safety_attributes["Violent"] == expected_score

def test_text_embedding_top_level_from_pretrained_preview(self):
    """Verifies vertexai.preview.from_pretrained resolves a TextEmbeddingModel."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        model_garden_service_client.ModelGardenServiceClient,
        "get_publisher_model",
        return_value=publisher_model,
    ) as mock_get_publisher_model:
        model = vertexai.preview.from_pretrained(
            foundation_model_name="textembedding-gecko@001"
        )

    assert isinstance(model, preview_language_models.TextEmbeddingModel)

    expected_endpoint = (
        f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
        "/publishers/google/models/textembedding-gecko@001"
    )
    assert model._endpoint_name == expected_endpoint

    mock_get_publisher_model.assert_called_with(
        name="publishers/google/models/textembedding-gecko@001",
        retry=base._DEFAULT_RETRY,
    )
    assert mock_get_publisher_model.call_count == 1


# TODO (b/285946649): add more test coverage before public preview release
@pytest.mark.usefixtures("google_auth_mock")
Expand Down
62 changes: 50 additions & 12 deletions tests/unit/aiplatform/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@
from google.cloud.aiplatform.compat.types import (
publisher_model as gca_publisher_model,
)
import vertexai
from vertexai import vision_models as ga_vision_models
from vertexai.preview import vision_models
from vertexai.preview import (
vision_models as preview_vision_models,
)

from PIL import Image as PIL_Image
import pytest
Expand Down Expand Up @@ -121,12 +124,12 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:

def generate_image_from_file(
    width: int = 100, height: int = 100
) -> ga_vision_models.Image:
    """Creates a blank RGB PNG of the given size and loads it as an Image.

    The PNG is written to a temporary directory that is removed when the
    context manager exits.
    NOTE(review): this assumes Image.load_from_file reads the file contents
    eagerly, since the temp path is gone once this function returns — confirm
    against the Image implementation.
    """
    # The diff scrape had left both the pre- and post-rename lines
    # (vision_models.* and ga_vision_models.*) in place; only the GA-namespace
    # version is kept here.
    with tempfile.TemporaryDirectory() as temp_dir:
        image_path = os.path.join(temp_dir, "image.png")
        pil_image = PIL_Image.new(mode="RGB", size=(width, height))
        pil_image.save(image_path, format="PNG")
        return ga_vision_models.Image.load_from_file(image_path)


@pytest.mark.usefixtures("google_auth_mock")
Expand All @@ -140,7 +143,7 @@ def setup_method(self):
def teardown_method(self):
    # Per-test cleanup: wait for any background work queued on the SDK's
    # global thread pool so it cannot leak into the next test.
    initializer.global_pool.shutdown(wait=True)

def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
def _get_image_generation_model(self) -> preview_vision_models.ImageGenerationModel:
"""Gets the image generation model."""
aiplatform.init(
project=_TEST_PROJECT,
Expand All @@ -153,7 +156,7 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:
_IMAGE_GENERATION_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = vision_models.ImageGenerationModel.from_pretrained(
model = preview_vision_models.ImageGenerationModel.from_pretrained(
"imagegeneration@002"
)

Expand All @@ -164,13 +167,48 @@ def _get_image_generation_model(self) -> vision_models.ImageGenerationModel:

return model

def _get_preview_image_generation_model_top_level_from_pretrained(
    self,
) -> preview_vision_models.ImageGenerationModel:
    """Loads the image generation model via vertexai.preview.from_pretrained."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    publisher_model = gca_publisher_model.PublisherModel(
        _IMAGE_GENERATION_PUBLISHER_MODEL_DICT
    )
    with mock.patch.object(
        model_garden_service_client.ModelGardenServiceClient,
        "get_publisher_model",
        return_value=publisher_model,
    ) as mock_get_publisher_model:
        model = vertexai.preview.from_pretrained(
            foundation_model_name="imagegeneration@002"
        )

    # Exactly one Model Garden lookup, with the fully qualified publisher name.
    mock_get_publisher_model.assert_called_with(
        name="publishers/google/models/imagegeneration@002",
        retry=base._DEFAULT_RETRY,
    )
    assert mock_get_publisher_model.call_count == 1

    return model

def test_from_pretrained(self):
    """The class-level from_pretrained path resolves the publisher endpoint name."""
    model = self._get_image_generation_model()
    expected_endpoint = (
        f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
        "/publishers/google/models/imagegeneration@002"
    )
    assert model._endpoint_name == expected_endpoint

def test_top_level_from_pretrained_preview(self):
    """The top-level vertexai.preview.from_pretrained path resolves the same endpoint."""
    model = self._get_preview_image_generation_model_top_level_from_pretrained()
    expected_endpoint = (
        f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
        "/publishers/google/models/imagegeneration@002"
    )
    assert model._endpoint_name == expected_endpoint

def test_generate_images(self):
"""Tests the image generation model."""
model = self._get_image_generation_model()
Expand Down Expand Up @@ -238,7 +276,7 @@ def test_generate_images(self):
with tempfile.TemporaryDirectory() as temp_dir:
image_path = os.path.join(temp_dir, "image.png")
image_response[0].save(location=image_path)
image1 = vision_models.GeneratedImage.load_from_file(image_path)
image1 = preview_vision_models.GeneratedImage.load_from_file(image_path)
# assert image1._pil_image.size == (width, height)
assert image1.generation_parameters
assert image1.generation_parameters["prompt"] == prompt1
Expand All @@ -247,7 +285,7 @@ def test_generate_images(self):
mask_path = os.path.join(temp_dir, "mask.png")
mask_pil_image = PIL_Image.new(mode="RGB", size=image1._pil_image.size)
mask_pil_image.save(mask_path, format="PNG")
mask_image = vision_models.Image.load_from_file(mask_path)
mask_image = preview_vision_models.Image.load_from_file(mask_path)

# Test generating image from base image
with mock.patch.object(
Expand Down Expand Up @@ -408,7 +446,7 @@ def test_upscale_image_on_provided_image(self):
assert image_upscale_parameters["mode"] == "upscale"

assert upscaled_image._image_bytes
assert isinstance(upscaled_image, vision_models.GeneratedImage)
assert isinstance(upscaled_image, preview_vision_models.GeneratedImage)

def test_upscale_image_raises_if_not_1024x1024(self):
"""Tests image upscaling on generated images."""
Expand Down Expand Up @@ -457,7 +495,7 @@ def test_get_captions(self):
image_path = os.path.join(temp_dir, "image.png")
pil_image = PIL_Image.new(mode="RGB", size=(100, 100))
pil_image.save(image_path, format="PNG")
image = vision_models.Image.load_from_file(image_path)
image = preview_vision_models.Image.load_from_file(image_path)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
Expand Down Expand Up @@ -544,7 +582,7 @@ def test_image_embedding_model_with_only_image(self):
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
) as mock_get_publisher_model:
model = vision_models.MultiModalEmbeddingModel.from_pretrained(
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

Expand Down Expand Up @@ -583,7 +621,7 @@ def test_image_embedding_model_with_image_and_text(self):
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
),
):
model = vision_models.MultiModalEmbeddingModel.from_pretrained(
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
"multimodalembedding@001"
)

Expand Down Expand Up @@ -715,7 +753,7 @@ def test_get_captions(self):
image_path = os.path.join(temp_dir, "image.png")
pil_image = PIL_Image.new(mode="RGB", size=(100, 100))
pil_image.save(image_path, format="PNG")
image = vision_models.Image.load_from_file(image_path)
image = preview_vision_models.Image.load_from_file(image_path)

with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
Expand Down
Loading

0 comments on commit f978200

Please sign in to comment.