From 96a850f2d24d7ae95f2cdec83a56362abecb85a2 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 1 Dec 2020 13:55:27 -0800 Subject: [PATCH] feat: add create_batch_prediction_job samples (#67) * chore: sample tests lint * lint * lnt * lint * feat: add create_batch_prediction_job samples * lint --- ...te_batch_prediction_job_bigquery_sample.py | 62 ++++++++++++++ ...tch_prediction_job_bigquery_sample_test.py | 82 ++++++++++++++++++ .../create_batch_prediction_job_sample.py | 69 +++++++++++++++ ...create_batch_prediction_job_sample_test.py | 83 +++++++++++++++++++ 4 files changed, 296 insertions(+) create mode 100644 samples/snippets/create_batch_prediction_job_bigquery_sample.py create mode 100644 samples/snippets/create_batch_prediction_job_bigquery_sample_test.py create mode 100644 samples/snippets/create_batch_prediction_job_sample.py create mode 100644 samples/snippets/create_batch_prediction_job_sample_test.py diff --git a/samples/snippets/create_batch_prediction_job_bigquery_sample.py b/samples/snippets/create_batch_prediction_job_bigquery_sample.py new file mode 100644 index 0000000000..7747333cab --- /dev/null +++ b/samples/snippets/create_batch_prediction_job_bigquery_sample.py @@ -0,0 +1,62 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_create_batch_prediction_job_bigquery_sample] +from google.cloud import aiplatform +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Value + + +def create_batch_prediction_job_bigquery_sample( + project: str, + display_name: str, + model_name: str, + instances_format: str, + bigquery_source_input_uri: str, + predictions_format: str, + bigquery_destination_output_uri: str, + location: str = "us-central1", + api_endpoint: str = "us-central1-aiplatform.googleapis.com", +): + client_options = {"api_endpoint": api_endpoint} + # Initialize client that will be used to create and send requests. + # This client only needs to be created once, and can be reused for multiple requests. + client = aiplatform.gapic.JobServiceClient(client_options=client_options) + model_parameters_dict = {} + model_parameters = json_format.ParseDict(model_parameters_dict, Value()) + + batch_prediction_job = { + "display_name": display_name, + # Format: 'projects/{project}/locations/{location}/models/{model_id}' + "model": model_name, + "model_parameters": model_parameters, + "input_config": { + "instances_format": instances_format, + "bigquery_source": {"input_uri": bigquery_source_input_uri}, + }, + "output_config": { + "predictions_format": predictions_format, + "bigquery_destination": {"output_uri": bigquery_destination_output_uri}, + }, + # optional + "generate_explanation": True, + } + parent = f"projects/{project}/locations/{location}" + response = client.create_batch_prediction_job( + parent=parent, batch_prediction_job=batch_prediction_job + ) + print("response:", response) + + +# [END aiplatform_create_batch_prediction_job_bigquery_sample] diff --git a/samples/snippets/create_batch_prediction_job_bigquery_sample_test.py b/samples/snippets/create_batch_prediction_job_bigquery_sample_test.py new file mode 100644 index 0000000000..663180ef35 --- /dev/null +++ b/samples/snippets/create_batch_prediction_job_bigquery_sample_test.py @@ -0,0 +1,82 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from uuid import uuid4 + +from google.cloud import aiplatform +import pytest + +import create_batch_prediction_job_bigquery_sample +import helpers + +PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") +LOCATION = "us-central1" +MODEL_ID = "3125638878883479552" # bq all +DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}" +BIGQUERY_SOURCE_INPUT_URI = "bq://ucaip-sample-tests.table_test.all_bq_types" +BIGQUERY_DESTINATION_OUTPUT_URI = "bq://ucaip-sample-tests" +INSTANCES_FORMAT = "bigquery" +PREDICTIONS_FORMAT = "bigquery" + + +@pytest.fixture +def shared_state(): + state = {} + yield state + + +@pytest.fixture +def job_client(): + job_client = aiplatform.gapic.JobServiceClient( + client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"} + ) + return job_client + + +@pytest.fixture(scope="function", autouse=True) +def teardown(shared_state, job_client): + yield + + job_client.cancel_batch_prediction_job(name=shared_state["batch_prediction_job_name"]) + + # Waiting until the job is in CANCELLED state. + helpers.wait_for_job_state( + get_job_method=job_client.get_batch_prediction_job, + name=shared_state["batch_prediction_job_name"], + ) + + job_client.delete_batch_prediction_job(name=shared_state["batch_prediction_job_name"]) + + +def test_ucaip_generated_create_batch_prediction_job_bigquery_sample( + capsys, shared_state +): + + model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}" + + create_batch_prediction_job_bigquery_sample.create_batch_prediction_job_bigquery_sample( + project=PROJECT_ID, + display_name=DISPLAY_NAME, + model_name=model_name, + bigquery_source_input_uri=BIGQUERY_SOURCE_INPUT_URI, + bigquery_destination_output_uri=BIGQUERY_DESTINATION_OUTPUT_URI, + instances_format=INSTANCES_FORMAT, + predictions_format=PREDICTIONS_FORMAT, + ) + + out, _ = capsys.readouterr() + + # Save resource name of the newly created batch prediction job + shared_state["batch_prediction_job_name"] = helpers.get_name(out) diff --git a/samples/snippets/create_batch_prediction_job_sample.py b/samples/snippets/create_batch_prediction_job_sample.py new file mode 100644 index 0000000000..ea89e7b885 --- /dev/null +++ b/samples/snippets/create_batch_prediction_job_sample.py @@ -0,0 +1,69 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_create_batch_prediction_job_sample] +from google.cloud import aiplatform +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Value + + +def create_batch_prediction_job_sample( + project: str, + display_name: str, + model_name: str, + instances_format: str, + gcs_source_uri: str, + predictions_format: str, + gcs_destination_output_uri_prefix: str, + location: str = "us-central1", + api_endpoint: str = "us-central1-aiplatform.googleapis.com", +): + client_options = {"api_endpoint": api_endpoint} + # Initialize client that will be used to create and send requests. + # This client only needs to be created once, and can be reused for multiple requests. + client = aiplatform.gapic.JobServiceClient(client_options=client_options) + model_parameters_dict = {} + model_parameters = json_format.ParseDict(model_parameters_dict, Value()) + + batch_prediction_job = { + "display_name": display_name, + # Format: 'projects/{project}/locations/{location}/models/{model_id}' + "model": model_name, + "model_parameters": model_parameters, + "input_config": { + "instances_format": instances_format, + "gcs_source": {"uris": [gcs_source_uri]}, + }, + "output_config": { + "predictions_format": predictions_format, + "gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix}, + }, + "dedicated_resources": { + "machine_spec": { + "machine_type": "n1-standard-2", + "accelerator_type": aiplatform.gapic.AcceleratorType.NVIDIA_TESLA_K80, + "accelerator_count": 1, + }, + "starting_replica_count": 1, + "max_replica_count": 1, + }, + } + parent = f"projects/{project}/locations/{location}" + response = client.create_batch_prediction_job( + parent=parent, batch_prediction_job=batch_prediction_job + ) + print("response:", response) + + +# [END aiplatform_create_batch_prediction_job_sample] diff --git a/samples/snippets/create_batch_prediction_job_sample_test.py b/samples/snippets/create_batch_prediction_job_sample_test.py new file mode 100644 index 0000000000..6804928024 --- /dev/null +++ b/samples/snippets/create_batch_prediction_job_sample_test.py @@ -0,0 +1,83 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from uuid import uuid4 + +from google.cloud import aiplatform +import pytest + +import create_batch_prediction_job_sample +import helpers + +PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") +LOCATION = "us-central1" +MODEL_ID = "1478306577684365312" # Permanent 50 flowers model +DISPLAY_NAME = f"temp_create_batch_prediction_job_test_{uuid4()}" +GCS_SOURCE_URI = ( + "gs://ucaip-samples-test-output/inputs/icn_batch_prediction_input.jsonl" +) +GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/" +INSTANCES_FORMAT = "jsonl" +PREDICTIONS_FORMAT = "jsonl" + + +@pytest.fixture +def shared_state(): + state = {} + yield state + + +@pytest.fixture +def job_client(): + job_client = aiplatform.gapic.JobServiceClient( + client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"} + ) + return job_client + + +@pytest.fixture(scope="function", autouse=True) +def teardown(shared_state, job_client): + yield + + job_client.cancel_batch_prediction_job(name=shared_state["batch_prediction_job_name"]) + + # Waiting until the job is in CANCELLED state. + helpers.wait_for_job_state( + get_job_method=job_client.get_batch_prediction_job, + name=shared_state["batch_prediction_job_name"], + ) + + job_client.delete_batch_prediction_job(name=shared_state["batch_prediction_job_name"]) + + +# Creating AutoML Vision Classification batch prediction job +def test_ucaip_generated_create_batch_prediction_sample(capsys, shared_state): + + model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}" + + create_batch_prediction_job_sample.create_batch_prediction_job_sample( + project=PROJECT_ID, + display_name=DISPLAY_NAME, + model_name=model_name, + gcs_source_uri=GCS_SOURCE_URI, + gcs_destination_output_uri_prefix=GCS_OUTPUT_URI, + instances_format=INSTANCES_FORMAT, + predictions_format=PREDICTIONS_FORMAT, + ) + + out, _ = capsys.readouterr() + + # Save resource name of the newly created batch prediction job + shared_state["batch_prediction_job_name"] = helpers.get_name(out)