From 3edf694df1c684c166f1c1f03fd4a2824ab96e9b Mon Sep 17 00:00:00 2001 From: Ivan Cheung Date: Mon, 4 Jan 2021 23:21:40 +0900 Subject: [PATCH] feat: Added tabular forecasting sample --- ...ediction_job_tabular_forecasting_sample.py | 28 ++++++++++ ...ediction_job_tabular_forecasting_sample.py | 54 +++++++++++++++++++ ...ion_job_tabular_forecasting_sample_test.py | 54 +++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 .sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py create mode 100644 samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py create mode 100644 samples/snippets/create_batch_prediction_job_tabular_forecasting_sample_test.py diff --git a/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py b/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py new file mode 100644 index 0000000000..d03f13dff1 --- /dev/null +++ b/.sample_configs/param_handlers/create_batch_prediction_job_tabular_forecasting_sample.py @@ -0,0 +1,28 @@ +def make_parent(parent: str) -> str: + parent = parent + + return parent + + +def make_batch_prediction_job( + display_name: str, + model_name: str, + gcs_source_uri: str, + gcs_destination_output_uri_prefix: str, + predictions_format: str, +) -> google.cloud.aiplatform_v1beta1.types.batch_prediction_job.BatchPredictionJob: + batch_prediction_job = { + "display_name": display_name, + # Format: 'projects/{project}/locations/{location}/models/{model_id}' + "model": model_name, + "input_config": { + "instances_format": predictions_format, + "gcs_source": {"uris": [gcs_source_uri]}, + }, + "output_config": { + "predictions_format": predictions_format, + "gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix}, + }, + } + return batch_prediction_job + diff --git a/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py b/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py new file mode 100644 index 0000000000..62eee08856 --- /dev/null +++ b/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py @@ -0,0 +1,54 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# [START aiplatform_create_batch_prediction_job_tabular_forecasting_sample] +from google.cloud import aiplatform + + +def create_batch_prediction_job_tabular_forecasting_sample( + project: str, + display_name: str, + model_name: str, + gcs_source_uri: str, + gcs_destination_output_uri_prefix: str, + predictions_format: str, + location: str = "us-central1", + api_endpoint: str = "us-central1-aiplatform.googleapis.com", +): + # The AI Platform services require regional API endpoints. + client_options = {"api_endpoint": api_endpoint} + # Initialize client that will be used to create and send requests. + # This client only needs to be created once, and can be reused for multiple requests. + client = aiplatform.gapic.JobServiceClient(client_options=client_options) + batch_prediction_job = { + "display_name": display_name, + # Format: 'projects/{project}/locations/{location}/models/{model_id}' + "model": model_name, + "input_config": { + "instances_format": predictions_format, + "gcs_source": {"uris": [gcs_source_uri]}, + }, + "output_config": { + "predictions_format": predictions_format, + "gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix}, + }, + } + parent = f"projects/{project}/locations/{location}" + response = client.create_batch_prediction_job( + parent=parent, batch_prediction_job=batch_prediction_job + ) + print("response:", response) + + +# [END aiplatform_create_batch_prediction_job_tabular_forecasting_sample] diff --git a/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample_test.py b/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample_test.py new file mode 100644 index 0000000000..f992cc575e --- /dev/null +++ b/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample_test.py @@ -0,0 +1,54 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from uuid import uuid4 + +import pytest + +import create_batch_prediction_job_tabular_forecasting_sample +import helpers + +PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT") +LOCATION = "us-central1" +MODEL_ID = "8531330622239539200" # Permanent restaurant rating model +DISPLAY_NAME = f"temp_create_batch_prediction_tabular_forecasting_test_{uuid4()}" +GCS_SOURCE_URI = "gs://cloud-samples-data/ai-platform/covid/bigquery-public-covid-nyt-us-counties-train.csv" +GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/" +PREDICTIONS_FORMAT = "csv" + + +@pytest.fixture(scope="function", autouse=True) +def teardown(teardown_batch_prediction_job): + yield + + +# Creating AutoML Tabular Forecasting Classification batch prediction job +def test_create_batch_prediction_job_tabular_forecasting_sample(capsys, shared_state): + + model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}" + + create_batch_prediction_job_tabular_forecasting_sample.create_batch_prediction_job_tabular_forecasting_sample( + project=PROJECT_ID, + display_name=DISPLAY_NAME, + model_name=model_name, + gcs_source_uri=GCS_SOURCE_URI, + gcs_destination_output_uri_prefix=GCS_OUTPUT_URI, + predictions_format=PREDICTIONS_FORMAT, + ) + + out, _ = capsys.readouterr() + + # Save resource name of the newly created batch prediction job + shared_state["batch_prediction_job_name"] = helpers.get_name(out)