-
Notifications
You must be signed in to change notification settings - Fork 348
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: GenAI - Add BatchPredictionJob for GenAI models
PiperOrigin-RevId: 633670470
- Loading branch information
1 parent
9bda328
commit df4a4f2
Showing
3 changed files
with
195 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Unit tests for generative model batch prediction.""" | ||
# pylint: disable=protected-access | ||
|
||
import pytest | ||
from unittest import mock | ||
|
||
import vertexai | ||
from google.cloud.aiplatform import base as aiplatform_base | ||
from google.cloud.aiplatform import initializer as aiplatform_initializer | ||
from google.cloud.aiplatform.compat.services import job_service_client | ||
from google.cloud.aiplatform.compat.types import ( | ||
batch_prediction_job as gca_batch_prediction_job_compat, | ||
) | ||
from google.cloud.aiplatform.compat.types import ( | ||
job_state as gca_job_state_compat, | ||
) | ||
from vertexai.preview import batch_prediction | ||
|
||
|
||
_TEST_PROJECT = "test-project" | ||
_TEST_LOCATION = "us-central1" | ||
|
||
_TEST_GEMINI_MODEL_NAME = "gemini-1.0-pro" | ||
_TEST_GEMINI_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_GEMINI_MODEL_NAME}" | ||
_TEST_PALM_MODEL_NAME = "text-bison" | ||
_TEST_PALM_MODEL_RESOURCE_NAME = f"publishers/google/models/{_TEST_PALM_MODEL_NAME}" | ||
|
||
_TEST_BATCH_PREDICTION_JOB_ID = "123456789" | ||
_TEST_BATCH_PREDICTION_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_BATCH_PREDICTION_JOB_ID}" | ||
_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4) | ||
|
||
|
||
# TODO(b/339230025) Mock the whole service instead of methods. | ||
@pytest.fixture | ||
def get_batch_prediction_job_mock(): | ||
with mock.patch.object( | ||
job_service_client.JobServiceClient, "get_batch_prediction_job" | ||
) as get_job_mock: | ||
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( | ||
name=_TEST_BATCH_PREDICTION_JOB_NAME, | ||
model=_TEST_GEMINI_MODEL_RESOURCE_NAME, | ||
state=_TEST_JOB_STATE_SUCCESS, | ||
) | ||
yield get_job_mock | ||
|
||
|
||
@pytest.fixture | ||
def get_batch_prediction_job_invalid_model_mock(): | ||
with mock.patch.object( | ||
job_service_client.JobServiceClient, "get_batch_prediction_job" | ||
) as get_job_mock: | ||
get_job_mock.return_value = gca_batch_prediction_job_compat.BatchPredictionJob( | ||
name=_TEST_BATCH_PREDICTION_JOB_NAME, | ||
model=_TEST_PALM_MODEL_RESOURCE_NAME, | ||
state=_TEST_JOB_STATE_SUCCESS, | ||
) | ||
yield get_job_mock | ||
|
||
|
||
@pytest.mark.usefixtures("google_auth_mock") | ||
class TestBatchPredictionJob: | ||
"""Unit tests for BatchPredictionJob.""" | ||
|
||
def setup_method(self): | ||
vertexai.init( | ||
project=_TEST_PROJECT, | ||
location=_TEST_LOCATION, | ||
) | ||
|
||
def teardown_method(self): | ||
aiplatform_initializer.global_pool.shutdown(wait=True) | ||
|
||
def test_init_batch_prediction_job(self, get_batch_prediction_job_mock): | ||
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) | ||
|
||
get_batch_prediction_job_mock.assert_called_once_with( | ||
name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=aiplatform_base._DEFAULT_RETRY | ||
) | ||
|
||
@pytest.mark.usefixtures("get_batch_prediction_job_invalid_model_mock") | ||
def test_init_batch_prediction_job_invalid_model(self): | ||
with pytest.raises( | ||
ValueError, | ||
match=( | ||
f"BatchPredictionJob '{_TEST_BATCH_PREDICTION_JOB_ID}' " | ||
f"runs with the model '{_TEST_PALM_MODEL_RESOURCE_NAME}', " | ||
"which is not a GenAI model." | ||
), | ||
): | ||
batch_prediction.BatchPredictionJob(_TEST_BATCH_PREDICTION_JOB_ID) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Class to support Batch Prediction with GenAI models.""" | ||
# pylint: disable=protected-access | ||
|
||
from google.cloud.aiplatform import base as aiplatform_base | ||
from google.cloud.aiplatform import utils as aiplatform_utils | ||
|
||
|
||
_LOGGER = aiplatform_base.Logger(__name__) | ||
|
||
_GEMINI_MODEL_PREFIX = "publishers/google/models/gemini" | ||
|
||
|
||
class BatchPredictionJob(aiplatform_base._VertexAiResourceNounPlus): | ||
"""Represents a BatchPredictionJob that runs with GenAI models.""" | ||
|
||
_resource_noun = "batchPredictionJobs" | ||
_getter_method = "get_batch_prediction_job" | ||
_list_method = "list_batch_prediction_jobs" | ||
_cancel_method = "cancel_batch_prediction_job" | ||
_delete_method = "delete_batch_prediction_job" | ||
_job_type = "batch-predictions" | ||
_parse_resource_name_method = "parse_batch_prediction_job_path" | ||
_format_resource_name_method = "batch_prediction_job_path" | ||
|
||
client_class = aiplatform_utils.JobClientWithOverride | ||
|
||
def __init__(self, batch_prediction_job_name: str): | ||
"""Retrieves a BatchPredictionJob resource that runs with a GenAI model. | ||
Args: | ||
batch_prediction_job_name (str): | ||
Required. A fully-qualified BatchPredictionJob resource name or | ||
ID. Example: "projects/.../locations/.../batchPredictionJobs/456" | ||
or "456" when project and location are initialized. | ||
Raises: | ||
ValueError: If batch_prediction_job_name represents a BatchPredictionJob | ||
resource that runs with another type of model. | ||
""" | ||
super().__init__(resource_name=batch_prediction_job_name) | ||
self._gca_resource = self._get_gca_resource( | ||
resource_name=batch_prediction_job_name | ||
) | ||
# TODO(b/338452508) Support tuned GenAI models. | ||
if not self._gca_resource.model.startswith(_GEMINI_MODEL_PREFIX): | ||
raise ValueError( | ||
f"BatchPredictionJob '{batch_prediction_job_name}' " | ||
f"runs with the model '{self._gca_resource.model}', " | ||
"which is not a GenAI model." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Classes for batch prediction.""" | ||
|
||
# We just want to re-export certain classes | ||
# pylint: disable=g-multiple-import,g-importing-member | ||
from vertexai.batch_prediction._batch_prediction import ( | ||
BatchPredictionJob, | ||
) | ||
|
||
__all__ = [ | ||
"BatchPredictionJob", | ||
] |