chore: [LLM] Added system tests for tuning
The tests cover tuning as well as listing and loading the tuned models.

PiperOrigin-RevId: 540433943
Ark-kun authored and copybara-github committed Jun 15, 2023
1 parent d4d8613 commit 8191035
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tests/system/aiplatform/test_language_models.py
@@ -81,3 +81,66 @@ def test_text_embedding(self):
        for embedding in embeddings:
            vector = embedding.values
            assert len(vector) == 768

    def test_tuning(self, shared_state):
        """Test tuning, listing and loading models."""
        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

        model = TextGenerationModel.from_pretrained("google/text-bison@001")

        import pandas

        training_data = pandas.DataFrame(
            data=[
                {"input_text": "Input 0", "output_text": "Output 0"},
                {"input_text": "Input 1", "output_text": "Output 1"},
                {"input_text": "Input 2", "output_text": "Output 2"},
                {"input_text": "Input 3", "output_text": "Output 3"},
                {"input_text": "Input 4", "output_text": "Output 4"},
                {"input_text": "Input 5", "output_text": "Output 5"},
                {"input_text": "Input 6", "output_text": "Output 6"},
                {"input_text": "Input 7", "output_text": "Output 7"},
                {"input_text": "Input 8", "output_text": "Output 8"},
                {"input_text": "Input 9", "output_text": "Output 9"},
            ]
        )

        model.tune_model(
            training_data=training_data,
            train_steps=1,
            tuning_job_location="europe-west4",
            tuned_model_location="us-central1",
        )
        # According to the Pipelines design, external resources created by a pipeline
        # must not be modified or deleted. Otherwise caching will break next pipeline runs.
        shared_state.setdefault("resources", [])
        shared_state["resources"].append(model._endpoint)
        shared_state["resources"].extend(
            aiplatform.Model(model_name=deployed_model.model)
            for deployed_model in model._endpoint.list_models()
        )
        # Deleting the Endpoint is a little less bad since the LLM SDK will recreate it,
        # but it's not advised for the same reason.

        response = model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        )
        assert response.text

        tuned_model_names = model.list_tuned_model_names()
        assert tuned_model_names
        tuned_model_name = tuned_model_names[0]

        tuned_model = TextGenerationModel.get_tuned_model(tuned_model_name)

        tuned_model_response = tuned_model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        )
        assert tuned_model_response.text
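
For reference, the list-and-load flow exercised by this test can be sketched outside the test harness roughly as follows. This is a minimal, hypothetical sketch and not part of the commit: the vertexai.preview.language_models import path and the project/location values are assumptions and may differ from the SDK version this repository pins.

    # Sketch of listing and loading a tuned model, mirroring the test above.
    # Assumptions (not part of this commit): the import path and the
    # "my-project" / "us-central1" placeholder values.
    from google.cloud import aiplatform
    from vertexai.preview.language_models import TextGenerationModel

    aiplatform.init(project="my-project", location="us-central1")

    base_model = TextGenerationModel.from_pretrained("google/text-bison@001")

    # List models previously tuned from this base model and load the first one.
    tuned_model_names = base_model.list_tuned_model_names()
    if tuned_model_names:
        tuned_model = TextGenerationModel.get_tuned_model(tuned_model_names[0])
        response = tuned_model.predict(
            "What is the best recipe for banana bread? Recipe:",
            max_output_tokens=128,
            temperature=0,
            top_p=1,
            top_k=5,
        )
        print(response.text)

Note that the test itself hands the created endpoint and tuned models to shared_state instead of deleting them, in line with the Pipelines caching constraint noted in the code comments.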
