Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add wait_for_resource_creation to BatchPredictionJob and unblock async creation when model is pending creation. #660

Merged
merged 10 commits into from
Aug 29, 2021
33 changes: 33 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,39 @@ Please visit `Importing models to Vertex AI`_ for a detailed overview:
.. _Importing models to Vertex AI: https://cloud.google.com/vertex-ai/docs/general/import-model


Batch Prediction
----------------

To create a batch prediction job:

.. code-block:: Python

model = aiplatform.Model('/projects/my-project/locations/us-central1/models/{MODEL_ID}')

batch_prediction_job = model.batch_predict(
job_display_name='my-batch-prediction-job',
instances_format='csv'
machine_type='n1-standard-4',
gcs_source=['gs://path/to/my/file.csv']
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/'
)

You can also create a batch prediction job asynchronously by including the `sync=False` argument:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this line be part of the code block?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not intended to be in the code block.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant to say that it might be better to move it after the .. code-block:: Python line.


.. code-block:: Python

batch_prediction_job = model.batch_predict(..., sync=False)

# wait for resource to be created
batch_prediction_job.wait_for_resource_creation()

# get the state
batch_prediction_job.state

# block until job is complete
batch_prediction_job.wait()


Endpoints
---------

Expand Down
16 changes: 10 additions & 6 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,17 +680,21 @@ def wrapper(*args, **kwargs):
inspect.getfullargspec(method).annotations["return"]
)

# object produced by the method
returned_object = bound_args.arguments.get(return_input_arg)

# is a classmethod that creates the object and returns it
if args and inspect.isclass(args[0]):
# assumes classmethod is our resource noun
returned_object = args[0]._empty_constructor()

# assumes class in classmethod is the resource noun
returned_object = (
args[0]._empty_constructor()
if not returned_object
else returned_object
)
self = returned_object

else: # instance method

# object produced by the method
returned_object = bound_args.arguments.get(return_input_arg)

# if we're returning an input object
if returned_object and returned_object is not self:

Expand Down
121 changes: 60 additions & 61 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils

from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
batch_prediction_job_v1 as gca_bp_job_v1,
Expand All @@ -58,6 +49,13 @@
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
study as gca_study_compat,
)
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils


_LOGGER = base.Logger(__name__)
Expand Down Expand Up @@ -352,7 +350,7 @@ def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
def create(
cls,
job_display_name: str,
model_name: str,
model_name: Union[str, "aiplatform.Model"],
instances_format: str = "jsonl",
predictions_format: str = "jsonl",
gcs_source: Optional[Union[str, Sequence[str]]] = None,
Expand Down Expand Up @@ -384,10 +382,12 @@ def create(
Required. The user-defined name of the BatchPredictionJob.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
model_name (str):
model_name (Union[str, aiplatform.Model]):
Required. A fully-qualified model resource name or model ID.
Example: "projects/123/locations/us-central1/models/456" or
"456" when project and location are initialized or passed.

Or an instance of aiplatform.Model.
instances_format (str):
Required. The format in which instances are given, must be one
of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
Expand Down Expand Up @@ -533,15 +533,17 @@ def create(
"""

utils.validate_display_name(job_display_name)

if labels:
utils.validate_labels(labels)

model_name = utils.full_resource_name(
resource_name=model_name,
resource_noun="models",
project=project,
location=location,
)
if isinstance(model_name, str):
model_name = utils.full_resource_name(
resource_name=model_name,
resource_noun="models",
project=project,
location=location,
)

# Raise error if both or neither source URIs are provided
if bool(gcs_source) == bool(bigquery_source):
Expand Down Expand Up @@ -570,6 +572,7 @@ def create(
f"{predictions_format} is not an accepted prediction format "
f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
)

gca_bp_job = gca_bp_job_compat
gca_io = gca_io_compat
gca_machine_resources = gca_machine_resources_compat
Expand All @@ -584,7 +587,6 @@ def create(

# Required Fields
gapic_batch_prediction_job.display_name = job_display_name
gapic_batch_prediction_job.model = model_name

input_config = gca_bp_job.BatchPredictionJob.InputConfig()
output_config = gca_bp_job.BatchPredictionJob.OutputConfig()
Expand Down Expand Up @@ -657,63 +659,43 @@ def create(
metadata=explanation_metadata, parameters=explanation_parameters
)

# TODO (b/174502913): Support private feature once released

api_client = cls._instantiate_client(location=location, credentials=credentials)
empty_batch_prediction_job = cls._empty_constructor(
project=project, location=location, credentials=credentials,
)

return cls._create(
api_client=api_client,
parent=initializer.global_config.common_location_path(
project=project, location=location
),
batch_prediction_job=gapic_batch_prediction_job,
empty_batch_prediction_job=empty_batch_prediction_job,
model_or_model_name=model_name,
gca_batch_prediction_job=gapic_batch_prediction_job,
generate_explanation=generate_explanation,
project=project or initializer.global_config.project,
location=location or initializer.global_config.location,
credentials=credentials or initializer.global_config.credentials,
sync=sync,
)

@classmethod
@base.optional_sync()
@base.optional_sync(return_input_arg="empty_batch_prediction_job")
def _create(
cls,
api_client: job_service_client.JobServiceClient,
parent: str,
batch_prediction_job: Union[
empty_batch_prediction_job: "BatchPredictionJob",
model_or_model_name: Union[str, "aiplatform.Model"],
gca_batch_prediction_job: Union[
gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
],
generate_explanation: bool,
project: str,
location: str,
credentials: Optional[auth_credentials.Credentials],
sync: bool = True,
) -> "BatchPredictionJob":
"""Create a batch prediction job.

Args:
api_client (dataset_service_client.DatasetServiceClient):
Required. An instance of DatasetServiceClient with the correct api_endpoint
already set based on user's preferences.
batch_prediction_job (gca_bp_job.BatchPredictionJob):
empty_batch_prediction_job (BatchPredictionJob):
Required. BatchPredictionJob without _gca_resource populated.
model_or_model_name (Union[str, aiplatform.Model]):
Required. Required. A fully-qualified model resource name or
an instance of aiplatform.Model.
gca_batch_prediction_job (gca_bp_job.BatchPredictionJob):
Required. a batch prediction job proto for creating a batch prediction job on Vertex AI.
generate_explanation (bool):
Required. Generate explanation along with the batch prediction
results.
parent (str):
Required. Also known as common location path, that usually contains the
project and location that the user provided to the upstream method.
Example: "projects/my-prj/locations/us-central1"
project (str):
Required. Project to upload this model to. Overrides project set in
aiplatform.init.
location (str):
Required. Location to upload this model to. Overrides location set in
aiplatform.init.
credentials (Optional[auth_credentials.Credentials]):
Custom credentials to use to upload this model. Overrides
credentials set in aiplatform.init.

Returns:
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
Expand All @@ -725,21 +707,34 @@ def _create(
by Vertex AI.
"""
# select v1beta1 if explain else use default v1

parent = initializer.global_config.common_location_path(
project=empty_batch_prediction_job.project,
location=empty_batch_prediction_job.location,
)

model_resource_name = (
model_or_model_name
if isinstance(model_or_model_name, str)
else model_or_model_name.resource_name
)

gca_batch_prediction_job.model = model_resource_name

api_client = empty_batch_prediction_job.api_client

if generate_explanation:
api_client = api_client.select_version(compat.V1BETA1)

_LOGGER.log_create_with_lro(cls)

gca_batch_prediction_job = api_client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
parent=parent, batch_prediction_job=gca_batch_prediction_job
)

batch_prediction_job = cls(
batch_prediction_job_name=gca_batch_prediction_job.name,
project=project,
location=location,
credentials=credentials,
)
empty_batch_prediction_job._gca_resource = gca_batch_prediction_job

batch_prediction_job = empty_batch_prediction_job

_LOGGER.log_create_complete(cls, batch_prediction_job._gca_resource, "bpj")

Expand Down Expand Up @@ -843,6 +838,10 @@ def iter_outputs(
f"on your prediction output:\n{output_info}"
)

def wait_for_resource_creation(self) -> None:
"""Waits until resource has been created."""
self._wait_for_resource_creation()


class _RunnableJob(_Job):
"""ABC to interface job as a runnable training class."""
Expand Down
4 changes: 1 addition & 3 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,6 @@ def undeploy(
if deployed_model_id in traffic_split and traffic_split[deployed_model_id]:
raise ValueError("Model being undeployed should have 0 traffic.")
if sum(traffic_split.values()) != 100:
# TODO(b/172678233) verify every referenced deployed model exists
raise ValueError(
"Sum of all traffic within traffic split needs to be 100."
)
Expand Down Expand Up @@ -2167,11 +2166,10 @@ def batch_predict(
(jobs.BatchPredictionJob):
Instantiated representation of the created batch prediction job.
"""
self.wait()

return jobs.BatchPredictionJob.create(
job_display_name=job_display_name,
model_name=self.resource_name,
model_name=self,
instances_format=instances_format,
predictions_format=predictions_format,
gcs_source=gcs_source,
Expand Down
11 changes: 11 additions & 0 deletions tests/system/aiplatform/e2e_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def _temp_prefix(cls) -> str:
"""
pass

@classmethod
def _make_display_name(cls, key: str) -> str:
"""Helper method to make unique display_names.

Args:
key (str): Required. Identifier for the display name.
Returns:
Unique display name.
"""
return f"{cls._temp_prefix}-{key}-{uuid.uuid4()}"

def setup_method(self):
importlib.reload(initializer)
importlib.reload(aiplatform)
Expand Down
Loading