From 4fb41ec40f2291b86e1d6905b713a52c848c04c8 Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Mon, 6 Dec 2021 11:18:32 -0800 Subject: [PATCH 1/4] fix: add param for multi-label per user's feedback --- .../create_training_pipeline_image_classification_sample.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample.py b/samples/model-builder/create_training_pipeline_image_classification_sample.py index 3786894a05..1f625f6c9f 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample.py @@ -24,6 +24,7 @@ def create_training_pipeline_image_classification_sample( display_name: str, dataset_id: int, model_display_name: Optional[str] = None, + multi_label: bool = False, training_fraction_split: float = 0.8, validation_fraction_split: float = 0.1, test_fraction_split: float = 0.1, @@ -33,7 +34,10 @@ def create_training_pipeline_image_classification_sample( ): aiplatform.init(project=project, location=location) - job = aiplatform.AutoMLImageTrainingJob(display_name=display_name) + job = aiplatform.AutoMLImageTrainingJob(display_name=display_name, + prediction_type='classification', + multi_label=multi_label + ) my_image_ds = aiplatform.ImageDataset(dataset_id) From 1b34b3c48cc1674cd4610498959c433c1246321a Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Mon, 6 Dec 2021 12:14:39 -0800 Subject: [PATCH 2/4] fix: indentation --- ...eate_training_pipeline_image_classification_sample.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample.py b/samples/model-builder/create_training_pipeline_image_classification_sample.py index 1f625f6c9f..615e468485 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample.py @@ -34,10 +34,11 @@ def create_training_pipeline_image_classification_sample( ): aiplatform.init(project=project, location=location) - job = aiplatform.AutoMLImageTrainingJob(display_name=display_name, - prediction_type='classification', - multi_label=multi_label - ) + job = aiplatform.AutoMLImageTrainingJob( + display_name=display_name, + prediction_type='classification', + multi_label=multi_label + ) my_image_ds = aiplatform.ImageDataset(dataset_id) From 7a25b434e2dca6bab02b964a559aaaf8c88a26c5 Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Mon, 6 Dec 2021 12:22:03 -0800 Subject: [PATCH 3/4] test: update assert for new params --- ...eate_training_pipeline_image_classification_sample_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py index 1c7080e7a1..e00c0a7d64 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py @@ -44,7 +44,9 @@ def test_create_training_pipeline_image_classification_sample( project=constants.PROJECT, location=constants.LOCATION ) mock_get_automl_image_training_job.assert_called_once_with( - display_name=constants.DISPLAY_NAME + display_name=constants.DISPLAY_NAME, + multi_label=False, + prediction_type='classification' ) mock_run_automl_image_training_job.assert_called_once_with( dataset=mock_image_dataset, From 1e2c9fbcdaee8ef4e1282e1da6a322e6aa2f1e8a Mon Sep 17 00:00:00 2001 From: Andrew Ferlitsch Date: Mon, 6 Dec 2021 12:31:37 -0800 Subject: [PATCH 4/4] lint: remove trailing whitespace --- ...create_training_pipeline_image_classification_sample_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py index e00c0a7d64..c5d7e14beb 100644 --- a/samples/model-builder/create_training_pipeline_image_classification_sample_test.py +++ b/samples/model-builder/create_training_pipeline_image_classification_sample_test.py @@ -45,7 +45,7 @@ def test_create_training_pipeline_image_classification_sample( ) mock_get_automl_image_training_job.assert_called_once_with( display_name=constants.DISPLAY_NAME, - multi_label=False, + multi_label=False, prediction_type='classification' ) mock_run_automl_image_training_job.assert_called_once_with(