Skip to content

Commit

Permalink
feat: LLM - Added support for async prediction methods
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 566381589
  • Loading branch information
Ark-kun authored and copybara-github committed Sep 18, 2023
1 parent 41d341e commit c9c9f10
Show file tree
Hide file tree
Showing 3 changed files with 369 additions and 17 deletions.
76 changes: 76 additions & 0 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

# pylint: disable=protected-access, g-multiple-import

import pytest


from google.cloud import aiplatform
from google.cloud.aiplatform.compat.types import (
job_state as gca_job_state,
Expand Down Expand Up @@ -54,6 +57,22 @@ def test_text_generation(self):
stop_sequences=["# %%"],
).text

@pytest.mark.asyncio
async def test_text_generation_model_predict_async(self):
    """System test: TextGenerationModel.predict_async returns non-empty text."""
    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    text_model = TextGenerationModel.from_pretrained("google/text-bison@001")

    prompt = "What is the best recipe for banana bread? Recipe:"
    generation_params = dict(
        max_output_tokens=128,
        temperature=0.0,
        top_p=1.0,
        top_k=5,
        stop_sequences=["# %%"],
    )
    prediction = await text_model.predict_async(prompt, **generation_params)
    # A live model is expected to produce some text for this prompt.
    assert prediction.text

def test_text_generation_streaming(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

Expand Down Expand Up @@ -107,6 +126,46 @@ def test_chat_on_chat_model(self):
assert chat.message_history[2].content == message2
assert chat.message_history[3].author == chat.MODEL_AUTHOR

@pytest.mark.asyncio
async def test_chat_model_async(self):
    """System test: ChatSession.send_message_async replies and records history."""
    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    model = ChatModel.from_pretrained("google/chat-bison@001")
    session = model.start_chat(
        context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
        examples=[
            InputOutputTextPair(
                input_text="Who do you work for?",
                output_text="I work for Ned.",
            ),
            InputOutputTextPair(
                input_text="What do I like?",
                output_text="Ned likes watching movies.",
            ),
        ],
        temperature=0.0,
        stop_sequences=["# %%"],
    )

    first_message = "Are my favorite movies based on a book series?"
    first_reply = await session.send_message_async(first_message)
    assert first_reply.text
    # One user turn plus one model turn should now be recorded.
    assert len(session.message_history) == 2
    assert session.message_history[0].author == session.USER_AUTHOR
    assert session.message_history[0].content == first_message
    assert session.message_history[1].author == session.MODEL_AUTHOR

    second_message = "When were these books published?"
    second_reply = await session.send_message_async(
        second_message,
        temperature=0.1,
    )
    assert second_reply.text
    # Two more turns appended: history grows to four entries.
    assert len(session.message_history) == 4
    assert session.message_history[2].author == session.USER_AUTHOR
    assert session.message_history[2].content == second_message
    assert session.message_history[3].author == session.MODEL_AUTHOR

def test_chat_model_send_message_streaming(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

Expand Down Expand Up @@ -161,6 +220,23 @@ def test_text_embedding(self):
assert embeddings[1].statistics.token_count > 1000
assert embeddings[1].statistics.truncated

@pytest.mark.asyncio
async def test_text_embedding_async(self):
    """System test: get_embeddings_async handles short and truncated-long inputs."""
    aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

    embedding_model = TextEmbeddingModel.from_pretrained(
        "google/textembedding-gecko@001"
    )
    # One short text and one very long text (the latter exercises truncation).
    texts = ["What is life?", "What is life?" * 1000]
    embeddings = await embedding_model.get_embeddings_async(texts)

    assert len(embeddings) == 2
    short_embedding, long_embedding = embeddings

    assert len(short_embedding.values) == 768
    assert short_embedding.statistics.token_count > 0
    assert not short_embedding.statistics.truncated

    assert len(long_embedding.values) == 768
    assert long_embedding.statistics.token_count > 1000
    assert long_embedding.statistics.truncated

def test_tuning(self, shared_state):
"""Test tuning, listing and loading models."""
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
Expand Down
48 changes: 47 additions & 1 deletion tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
model_service_client,
pipeline_service_client,
)
from google.cloud.aiplatform.compat.services import prediction_service_client
from google.cloud.aiplatform.compat.services import (
prediction_service_client,
prediction_service_async_client,
)
from google.cloud.aiplatform.compat.types import (
artifact as gca_artifact,
prediction_service as gca_prediction_service,
Expand Down Expand Up @@ -1273,6 +1276,49 @@ def test_text_generation_ga(self):
assert "topP" not in prediction_parameters
assert "topK" not in prediction_parameters

@pytest.mark.asyncio
async def test_text_generation_async(self):
    """Unit test: predict_async forwards wire-format parameters and parses the reply."""
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    # Stub out the publisher-model lookup performed by `from_pretrained`.
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _TEXT_BISON_PUBLISHER_MODEL_DICT
        ),
    ):
        model = language_models.TextGenerationModel.from_pretrained(
            "text-bison@001"
        )

    fake_response = gca_prediction_service.PredictResponse()
    fake_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)

    # Patch the async prediction client so no RPC is actually issued;
    # mock.patch.object creates an AsyncMock for the async `predict` method.
    with mock.patch.object(
        target=prediction_service_async_client.PredictionServiceAsyncClient,
        attribute="predict",
        return_value=fake_response,
    ) as predict_mock:
        response = await model.predict_async(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0.0,
            top_p=1.0,
            top_k=5,
            stop_sequences=["\n"],
        )

        sent_parameters = predict_mock.call_args[1]["parameters"]
        # Each user-facing argument must map to its wire-format key.
        assert sent_parameters["maxDecodeSteps"] == 128
        assert sent_parameters["temperature"] == 0.0
        assert sent_parameters["topP"] == 1.0
        assert sent_parameters["topK"] == 5
        assert sent_parameters["stopSequences"] == ["\n"]
        assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]

def test_text_generation_model_predict_streaming(self):
"""Tests the TextGenerationModel.predict_streaming method."""
with mock.patch.object(
Expand Down
Loading

0 comments on commit c9c9f10

Please sign in to comment.