diff --git a/samples/model-builder/create_training_pipeline_custom_container_job_sample.py b/samples/model-builder/create_training_pipeline_custom_container_job_sample.py index 0f63cc36f2..56a6494bc9 100644 --- a/samples/model-builder/create_training_pipeline_custom_container_job_sample.py +++ b/samples/model-builder/create_training_pipeline_custom_container_job_sample.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from google.cloud import aiplatform from typing import List, Optional, Union +from google.cloud import aiplatform + # [START aiplatform_sdk_create_training_pipeline_custom_container_job_sample] def create_training_pipeline_custom_container_job_sample( @@ -22,8 +23,8 @@ def create_training_pipeline_custom_container_job_sample( location: str, staging_bucket: str, display_name: str, - container_uri: str, - model_serving_container_image_uri: str, + container_uri: str, + model_serving_container_image_uri: str, model_display_name: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, replica_count: int = 1, @@ -52,7 +53,7 @@ def create_training_pipeline_custom_container_job_sample( accelerator_count=accelerator_count, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, - test_fraction_split=test_fraction_split, + test_fraction_split=test_fraction_split, sync=sync, ) diff --git a/samples/model-builder/create_training_pipeline_custom_container_job_test.py b/samples/model-builder/create_training_pipeline_custom_container_job_test.py index 4d1cc2ac5e..2037b8f74e 100644 --- a/samples/model-builder/create_training_pipeline_custom_container_job_test.py +++ b/samples/model-builder/create_training_pipeline_custom_container_job_test.py @@ -42,9 +42,11 @@ def test_create_training_pipeline_custom_container_job_sample( ) mock_sdk_init.assert_called_once_with( - project=constants.PROJECT, location=constants.LOCATION, staging_bucket=constants.STAGING_BUCKET + project=constants.PROJECT, + location=constants.LOCATION, + staging_bucket=constants.STAGING_BUCKET, ) - + mock_init_custom_container_training_job.assert_called_once_with( display_name=constants.DISPLAY_NAME, container_uri=constants.CONTAINER_URI, diff --git a/samples/model-builder/create_training_pipeline_custom_package_job_sample.py b/samples/model-builder/create_training_pipeline_custom_package_job_sample.py index f27a0bbc3a..627a1b5dab 100644 --- a/samples/model-builder/create_training_pipeline_custom_package_job_sample.py +++ b/samples/model-builder/create_training_pipeline_custom_package_job_sample.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from google.cloud import aiplatform from typing import List, Optional, Union +from google.cloud import aiplatform + # [START aiplatform_sdk_create_training_pipeline_custom_package_job_sample] def create_training_pipeline_custom_package_job_sample( @@ -24,8 +25,8 @@ def create_training_pipeline_custom_package_job_sample( display_name: str, python_package_gcs_uri: str, python_module_name: str, - container_uri: str, - model_serving_container_image_uri: str, + container_uri: str, + model_serving_container_image_uri: str, model_display_name: Optional[str] = None, args: Optional[List[Union[str, float, int]]] = None, replica_count: int = 1, @@ -56,7 +57,7 @@ def create_training_pipeline_custom_package_job_sample( accelerator_count=accelerator_count, training_fraction_split=training_fraction_split, validation_fraction_split=validation_fraction_split, - test_fraction_split=test_fraction_split, + test_fraction_split=test_fraction_split, sync=sync, ) diff --git a/samples/model-builder/create_training_pipeline_custom_package_job_test.py b/samples/model-builder/create_training_pipeline_custom_package_job_test.py index fdb3bb3dae..cf9daf0ac4 100644 --- a/samples/model-builder/create_training_pipeline_custom_package_job_test.py +++ b/samples/model-builder/create_training_pipeline_custom_package_job_test.py @@ -29,7 +29,7 @@ def test_create_training_pipeline_custom_package_job_sample( staging_bucket=constants.STAGING_BUCKET, display_name=constants.DISPLAY_NAME, python_package_gcs_uri=constants.PYTHON_PACKAGE_GCS_URI, - python_module_name=constants.PYTHON_MODULE_NAME, + python_module_name=constants.PYTHON_MODULE_NAME, container_uri=constants.CONTAINER_URI, args=constants.ARGS, model_serving_container_image_uri=constants.CONTAINER_URI, @@ -44,9 +44,11 @@ def test_create_training_pipeline_custom_package_job_sample( ) mock_sdk_init.assert_called_once_with( - project=constants.PROJECT, location=constants.LOCATION, staging_bucket=constants.STAGING_BUCKET + project=constants.PROJECT, + location=constants.LOCATION, + staging_bucket=constants.STAGING_BUCKET, ) - + mock_init_custom_package_training_job.assert_called_once_with( display_name=constants.DISPLAY_NAME, python_package_gcs_uri=constants.PYTHON_PACKAGE_GCS_URI,