Skip to content

Commit

Permalink
feat: adds text batch prediction samples (#82)
Browse files Browse the repository at this point in the history
* feat: adds text batch prediction samples

* fix: lint

* fix: broken test

* fix: tests

* fix: more changes

* fix: TSA batch prediction test updates

* fix: working model (I hope!)

Co-authored-by: Yu-Han Liu <[email protected]>
  • Loading branch information
telpirion and dizcology authored Nov 26, 2020
1 parent b012283 commit ad09c29
Show file tree
Hide file tree
Showing 13 changed files with 431 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_create_batch_prediction_job_text_classification_sample]
from google.cloud import aiplatform
from google.protobuf.struct_pb2 import Value


def create_batch_prediction_job_text_classification_sample(
project: str,
display_name: str,
model: str,
gcs_source_uri: str,
gcs_destination_output_uri_prefix: str,
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

batch_prediction_job = {
"display_name": display_name,
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
"model": model,
"model_parameters": Value(),
"input_config": {
"instances_format": "jsonl",
"gcs_source": {"uris": [gcs_source_uri]},
},
"output_config": {
"predictions_format": "jsonl",
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
},
}
parent = f"projects/{project}/locations/{location}"
response = client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
)
print("response:", response)


# [END aiplatform_create_batch_prediction_job_text_classification_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import uuid4
import pytest
import os

import helpers

import create_batch_prediction_job_text_classification_sample
import cancel_batch_prediction_job_sample
import delete_batch_prediction_job_sample

from google.cloud import aiplatform

PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
LOCATION = "us-central1"
MODEL_ID = "3863595899074641920" # Permanent restaurant rating model
DISPLAY_NAME = f"temp_create_batch_prediction_tcn_test_{uuid4()}"
GCS_SOURCE_URI = (
"gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl"
)
GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"


@pytest.fixture(scope="function")
def shared_state():

shared_state = {}

yield shared_state

assert "/" in shared_state["batch_prediction_job_name"]

batch_prediction_job = shared_state["batch_prediction_job_name"].split("/")[-1]

# Stop the batch prediction job
cancel_batch_prediction_job_sample.cancel_batch_prediction_job_sample(
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
)

job_client = aiplatform.gapic.JobServiceClient(
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)

# Waiting for batch prediction job to be in CANCELLED state
helpers.wait_for_job_state(
get_job_method=job_client.get_batch_prediction_job,
name=shared_state["batch_prediction_job_name"],
)

# Delete the batch prediction job
delete_batch_prediction_job_sample.delete_batch_prediction_job_sample(
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
)


# Creating AutoML Text Classification batch prediction job
def test_ucaip_generated_create_batch_prediction_tcn_sample(capsys, shared_state):

model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"

create_batch_prediction_job_text_classification_sample.create_batch_prediction_job_text_classification_sample(
project=PROJECT_ID,
display_name=DISPLAY_NAME,
model=model_name,
gcs_source_uri=GCS_SOURCE_URI,
gcs_destination_output_uri_prefix=GCS_OUTPUT_URI,
)

out, _ = capsys.readouterr()

# Save resource name of the newly created batch prediction job
shared_state["batch_prediction_job_name"] = helpers.get_name(out)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_create_batch_prediction_job_text_entity_extraction_sample]
from google.cloud import aiplatform
from google.protobuf.struct_pb2 import Value


def create_batch_prediction_job_text_entity_extraction_sample(
project: str,
display_name: str,
model: str,
gcs_source_uri: str,
gcs_destination_output_uri_prefix: str,
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

batch_prediction_job = {
"display_name": display_name,
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
"model": model,
"model_parameters": Value(),
"input_config": {
"instances_format": "jsonl",
"gcs_source": {"uris": [gcs_source_uri]},
},
"output_config": {
"predictions_format": "jsonl",
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
},
}
parent = f"projects/{project}/locations/{location}"
response = client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
)
print("response:", response)


# [END aiplatform_create_batch_prediction_job_text_entity_extraction_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from uuid import uuid4
import pytest
import os

import helpers

import create_batch_prediction_job_text_entity_extraction_sample
import cancel_batch_prediction_job_sample
import delete_batch_prediction_job_sample

from google.cloud import aiplatform

PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
LOCATION = "us-central1"
MODEL_ID = "5216364637146054656" # Permanent medical entity NL model
DISPLAY_NAME = f"temp_create_batch_prediction_ten_test_{uuid4()}"
GCS_SOURCE_URI = (
"gs://ucaip-samples-test-output/inputs/batch_predict_TEN/ten_inputs.jsonl"
)
GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"


@pytest.fixture(scope="function")
def shared_state():

shared_state = {}

yield shared_state

assert "/" in shared_state["batch_prediction_job_name"]

batch_prediction_job = shared_state["batch_prediction_job_name"].split("/")[-1]

# Stop the batch prediction job
cancel_batch_prediction_job_sample.cancel_batch_prediction_job_sample(
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
)

job_client = aiplatform.gapic.JobServiceClient(
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)

# Waiting for batch prediction job to be in CANCELLED state
helpers.wait_for_job_state(
get_job_method=job_client.get_batch_prediction_job,
name=shared_state["batch_prediction_job_name"],
)

# Delete the batch prediction job
delete_batch_prediction_job_sample.delete_batch_prediction_job_sample(
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
)


# Creating AutoML Text Entity Extraction batch prediction job
def test_ucaip_generated_create_batch_prediction_ten_sample(capsys, shared_state):

model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"

create_batch_prediction_job_text_entity_extraction_sample.create_batch_prediction_job_text_entity_extraction_sample(
project=PROJECT_ID,
display_name=DISPLAY_NAME,
model=model_name,
gcs_source_uri=GCS_SOURCE_URI,
gcs_destination_output_uri_prefix=GCS_OUTPUT_URI,
)

out, _ = capsys.readouterr()

# Save resource name of the newly created batch prediction job
shared_state["batch_prediction_job_name"] = helpers.get_name(out)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample]
from google.cloud import aiplatform
from google.protobuf.struct_pb2 import Value


def create_batch_prediction_job_text_sentiment_analysis_sample(
project: str,
display_name: str,
model: str,
gcs_source_uri: str,
gcs_destination_output_uri_prefix: str,
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

batch_prediction_job = {
"display_name": display_name,
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
"model": model,
"model_parameters": Value(),
"input_config": {
"instances_format": "jsonl",
"gcs_source": {"uris": [gcs_source_uri]},
},
"output_config": {
"predictions_format": "jsonl",
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
},
}
parent = f"projects/{project}/locations/{location}"
response = client.create_batch_prediction_job(
parent=parent, batch_prediction_job=batch_prediction_job
)
print("response:", response)


# [END aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample]
Loading

0 comments on commit ad09c29

Please sign in to comment.