Skip to content

Commit

Permalink
feat: Add AutoML vision, Custom training job, and generic prediction …
Browse files Browse the repository at this point in the history
…samples (#300)

* debug mock issue

* new mock

* more samples

* more samples

* add next sample/test

* add sample/test

* run black

* Add new Dataset import mocks, fix MBSDK sample tests

* Add license headers, update Endpoint mocks/usage

* type updates

* sasha comment fixes

* fix test errors after review update

* fix: type for instances

* Lint SDK samples

* Fix flake8 import order nits

Co-authored-by: Vinny Senthil <[email protected]>
Co-authored-by: sasha-gitg <[email protected]>
  • Loading branch information
3 people authored May 12, 2021
1 parent 56273f7 commit cc1a708
Show file tree
Hide file tree
Showing 16 changed files with 599 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform


# [START aiplatform_sdk_automl_image_classification_training_job_sample]
def automl_image_classification_training_job_sample(
    project: str,
    location: str,
    dataset_id: str,
    display_name: str,
):
    """Train a single-label AutoML image classification model.

    Args:
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        dataset_id: Identifier handed to ``aiplatform.ImageDataset`` to
            reference the dataset used for training.
        display_name: Display name used for both the training job and the
            model it produces.

    Returns:
        The object returned by ``AutoMLImageTrainingJob.run``.
    """
    aiplatform.init(project=project, location=location)

    image_dataset = aiplatform.ImageDataset(dataset_id)

    training_job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        prediction_type="classification",
        multi_label=False,
        model_type="CLOUD",
        base_model=None,
    )

    # Fractions are train / validation / test and sum to 1.0.
    trained_model = training_job.run(
        dataset=image_dataset,
        model_display_name=display_name,
        training_fraction_split=0.6,
        validation_fraction_split=0.2,
        test_fraction_split=0.2,
        budget_milli_node_hours=8000,
        disable_early_stopping=False,
    )

    # Print the model's identifying fields, one per line.
    for field in ("display_name", "name", "resource_name", "description", "uri"):
        print(getattr(trained_model, field))

    return trained_model


# [END aiplatform_sdk_automl_image_classification_training_job_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import automl_image_classification_training_job_sample
import test_constants as constants


def test_automl_image_classification_training_job_sample(
    mock_sdk_init,
    mock_image_dataset,
    mock_get_image_dataset,
    mock_get_automl_image_training_job,
    mock_run_automl_image_training_job,
):
    """Check the arguments the sample passes to each mocked SDK entry point.

    All parameters are pytest fixtures defined outside this module.
    """
    sample_module = automl_image_classification_training_job_sample
    sample_module.automl_image_classification_training_job_sample(
        project=constants.PROJECT,
        location=constants.LOCATION,
        dataset_id=constants.DATASET_NAME,
        display_name=constants.DISPLAY_NAME,
    )

    # The dataset is looked up by the identifier given to the sample.
    mock_get_image_dataset.assert_called_once_with(constants.DATASET_NAME)

    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    expected_job_kwargs = dict(
        display_name=constants.DISPLAY_NAME,
        base_model=None,
        model_type="CLOUD",
        multi_label=False,
        prediction_type="classification",
    )
    mock_get_automl_image_training_job.assert_called_once_with(
        **expected_job_kwargs
    )

    expected_run_kwargs = dict(
        budget_milli_node_hours=8000,
        disable_early_stopping=False,
        test_fraction_split=0.2,
        training_fraction_split=0.6,
        validation_fraction_split=0.2,
        model_display_name=constants.DISPLAY_NAME,
        dataset=mock_image_dataset,
    )
    mock_run_automl_image_training_job.assert_called_once_with(
        **expected_run_kwargs
    )
19 changes: 19 additions & 0 deletions samples/model-builder/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ def mock_create_video_dataset(mock_video_dataset):
"""Mocks for SomeDataset.import_data() """


@pytest.fixture
def mock_import_image_dataset(mock_image_dataset):
    """Patch ``import_data`` on the image dataset mock and yield the patch."""
    with patch.object(mock_image_dataset, "import_data") as mock_import_data:
        yield mock_import_data


@pytest.fixture
def mock_import_tabular_dataset(mock_tabular_dataset):
    """Patch ``import_data`` on the tabular dataset mock and yield the patch."""
    with patch.object(mock_tabular_dataset, "import_data") as mock_import_data:
        yield mock_import_data


@pytest.fixture
def mock_import_text_dataset(mock_text_dataset):
with patch.object(mock_text_dataset, "import_data") as mock:
Expand Down Expand Up @@ -327,6 +339,13 @@ def mock_get_endpoint(mock_endpoint):
yield mock_get_endpoint


@pytest.fixture
def mock_endpoint_predict(mock_endpoint):
    """Patch ``predict`` on the endpoint mock so that it returns ``[]``."""
    # Extra keyword arguments to patch.object configure the created mock.
    with patch.object(mock_endpoint, "predict", return_value=[]) as mock_predict:
        yield mock_predict


@pytest.fixture
def mock_endpoint_explain(mock_endpoint):
with patch.object(mock_endpoint, "explain") as mock_endpoint_explain:
Expand Down
49 changes: 49 additions & 0 deletions samples/model-builder/custom_training_job_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform


# [START aiplatform_sdk_custom_training_job_sample]
def custom_training_job_sample(
    project: str,
    location: str,
    bucket: str,
    display_name: str,
    script_path: str,
    script_args: list,
    container_uri: str,
    model_serving_container_image_uri: str,
    requirements: list,
    replica_count: int,
):
    """Run a custom training job from a script and return its result.

    Args:
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        bucket: Forwarded to ``aiplatform.init`` as ``staging_bucket``.
        display_name: Display name used for both the training job and the
            model it produces.
        script_path: Forwarded to ``CustomTrainingJob`` as ``script_path``.
        script_args: Forwarded to ``job.run`` as ``args``. Annotated as a
            list because the arguments are passed as a collection rather
            than as one string.
        container_uri: Forwarded to ``CustomTrainingJob`` as
            ``container_uri``.
        model_serving_container_image_uri: Forwarded to
            ``CustomTrainingJob`` under the same name.
        requirements: Forwarded to ``CustomTrainingJob`` as
            ``requirements``. Annotated as a list; callers may pass an
            empty list.
        replica_count: Forwarded to ``job.run`` as ``replica_count``.

    Returns:
        The object returned by ``CustomTrainingJob.run``.
    """
    aiplatform.init(project=project, location=location, staging_bucket=bucket)

    job = aiplatform.CustomTrainingJob(
        display_name=display_name,
        script_path=script_path,
        container_uri=container_uri,
        requirements=requirements,
        model_serving_container_image_uri=model_serving_container_image_uri,
    )

    model = job.run(
        args=script_args, replica_count=replica_count, model_display_name=display_name
    )

    return model


# [END aiplatform_sdk_custom_training_job_sample]
50 changes: 50 additions & 0 deletions samples/model-builder/custom_training_job_sample_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import custom_training_job_sample
import test_constants as constants


def test_custom_training_job_sample(
    mock_sdk_init, mock_get_custom_training_job, mock_run_custom_training_job
):
    """Check how the sample initializes the SDK, builds the job and runs it.

    All parameters are pytest fixtures defined outside this module.
    """
    sample_kwargs = dict(
        project=constants.PROJECT,
        location=constants.LOCATION,
        bucket=constants.STAGING_BUCKET,
        display_name=constants.DISPLAY_NAME,
        script_path=constants.PYTHON_PACKAGE,
        script_args=constants.PYTHON_PACKAGE_CMDARGS,
        container_uri=constants.TRAIN_IMAGE,
        model_serving_container_image_uri=constants.DEPLOY_IMAGE,
        requirements=[],
        replica_count=1,
    )
    custom_training_job_sample.custom_training_job_sample(**sample_kwargs)

    # The bucket must reach init() under the staging_bucket keyword.
    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT,
        location=constants.LOCATION,
        staging_bucket=constants.STAGING_BUCKET,
    )

    expected_job_kwargs = dict(
        display_name=constants.DISPLAY_NAME,
        container_uri=constants.TRAIN_IMAGE,
        model_serving_container_image_uri=constants.DEPLOY_IMAGE,
        requirements=[],
        script_path=constants.PYTHON_PACKAGE,
    )
    mock_get_custom_training_job.assert_called_once_with(**expected_job_kwargs)

    # Only the call count of run() is checked, not its arguments.
    mock_run_custom_training_job.assert_called_once()
32 changes: 32 additions & 0 deletions samples/model-builder/endpoint_predict_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform


# [START aiplatform_sdk_endpoint_predict_sample]
def endpoint_predict_sample(
    project: str, location: str, instances: list, endpoint: str
):
    """Send ``instances`` to an endpoint, print the result and return it.

    Args:
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        instances: Forwarded to ``Endpoint.predict`` as ``instances``.
        endpoint: Identifier handed to ``aiplatform.Endpoint``.

    Returns:
        The object returned by ``Endpoint.predict``.
    """
    aiplatform.init(project=project, location=location)

    # Keep the SDK object under its own name rather than rebinding the parameter.
    endpoint_resource = aiplatform.Endpoint(endpoint)

    response = endpoint_resource.predict(instances=instances)
    print(response)
    return response


# [END aiplatform_sdk_endpoint_predict_sample]
37 changes: 37 additions & 0 deletions samples/model-builder/endpoint_predict_sample_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import endpoint_predict_sample
import test_constants as constants


def test_endpoint_predict_sample(
    mock_sdk_init, mock_endpoint_predict, mock_get_endpoint
):
    """Check the SDK init, endpoint lookup and predict calls made by the sample.

    All parameters are pytest fixtures defined outside this module.
    """
    sample_kwargs = dict(
        project=constants.PROJECT,
        location=constants.LOCATION,
        instances=[],
        endpoint=constants.ENDPOINT_NAME,
    )
    endpoint_predict_sample.endpoint_predict_sample(**sample_kwargs)

    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    # The endpoint is looked up by the identifier given to the sample.
    mock_get_endpoint.assert_called_once_with(constants.ENDPOINT_NAME)

    mock_endpoint_predict.assert_called_once_with(instances=[])
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud import aiplatform


# [START aiplatform_sdk_image_dataset_create_classification_sample]
def image_dataset_create_classification_sample(
    project: str, location: str, display_name: str, src_uris: list
):
    """Create an image dataset using the single-label classification schema.

    Args:
        project: Project forwarded to ``aiplatform.init``.
        location: Location forwarded to ``aiplatform.init``.
        display_name: Display name of the dataset to create.
        src_uris: Forwarded to ``ImageDataset.create`` as ``gcs_source``.

    Returns:
        The object returned by ``ImageDataset.create``.
    """
    aiplatform.init(project=project, location=location)

    single_label_schema = (
        aiplatform.schema.dataset.ioformat.image.single_label_classification
    )
    image_dataset = aiplatform.ImageDataset.create(
        display_name=display_name,
        gcs_source=src_uris,
        import_schema_uri=single_label_schema,
    )

    # Print the new dataset's identifying fields, one per line.
    for field in ("display_name", "name", "resource_name", "metadata_schema_uri"):
        print(getattr(image_dataset, field))
    return image_dataset


# [END aiplatform_sdk_image_dataset_create_classification_sample]
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from google.cloud.aiplatform import schema

import image_dataset_create_classification_sample

import test_constants as constants


def test_image_dataset_create_classification_sample(
    mock_sdk_init, mock_create_image_dataset
):
    """Check the SDK init and dataset-create calls made by the sample.

    All parameters are pytest fixtures defined outside this module.
    """
    sample_module = image_dataset_create_classification_sample
    sample_module.image_dataset_create_classification_sample(
        project=constants.PROJECT,
        location=constants.LOCATION,
        src_uris=constants.GCS_SOURCES,
        display_name=constants.DISPLAY_NAME,
    )

    mock_sdk_init.assert_called_once_with(
        project=constants.PROJECT, location=constants.LOCATION
    )

    # src_uris must reach create() under the gcs_source keyword.
    expected_create_kwargs = dict(
        display_name=constants.DISPLAY_NAME,
        gcs_source=constants.GCS_SOURCES,
        import_schema_uri=schema.dataset.ioformat.image.single_label_classification,
    )
    mock_create_image_dataset.assert_called_once_with(**expected_create_kwargs)
Loading

0 comments on commit cc1a708

Please sign in to comment.