diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index fe47ef3c5a..667d4a558f 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -4736,7 +4736,7 @@ def run( model_labels: Optional[Dict[str, str]] = None, disable_early_stopping: bool = False, sync: bool = True, - create_request_timeout: Optional[float] = False, + create_request_timeout: Optional[float] = None, ) -> models.Model: """Runs the AutoML Image training job and returns a model. diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 06eaa9218a..7ee985ab70 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -4737,6 +4737,160 @@ def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout( timeout=180.0, ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_call_pipeline_service_create_with_tabular_dataset_with_timeout_not_explicitly_set( + self, + mock_pipeline_service_create, + mock_pipeline_service_get, + mock_tabular_dataset, + mock_model_service_get, + sync, + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = training_jobs.CustomPythonPackageTrainingJob( + display_name=_TEST_DISPLAY_NAME, + labels=_TEST_LABELS, + python_package_gcs_uri=_TEST_OUTPUT_PYTHON_PACKAGE_PATH, + python_module_name=_TEST_PYTHON_MODULE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + model_serving_container_image_uri=_TEST_SERVING_CONTAINER_IMAGE, + model_serving_container_predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + model_serving_container_health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + model_serving_container_command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + model_serving_container_args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + model_serving_container_environment_variables=_TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, + model_serving_container_ports=_TEST_MODEL_SERVING_CONTAINER_PORTS, + model_description=_TEST_MODEL_DESCRIPTION, + model_instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + model_parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + model_prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ) + + model_from_job = job.run( + dataset=mock_tabular_dataset, + model_display_name=_TEST_MODEL_DISPLAY_NAME, + model_labels=_TEST_MODEL_LABELS, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + args=_TEST_RUN_ARGS, + environment_variables=_TEST_ENVIRONMENT_VARIABLES, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + training_fraction_split=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction_split=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction_split=_TEST_TEST_FRACTION_SPLIT, + sync=sync, + ) + + if not sync: + model_from_job.wait() + + true_args = _TEST_RUN_ARGS + true_env = [ + {"name": key, "value": value} + for key, value in _TEST_ENVIRONMENT_VARIABLES.items() + ] + + true_worker_pool_spec = { + "replica_count": _TEST_REPLICA_COUNT, + "machine_spec": { + "machine_type": _TEST_MACHINE_TYPE, + "accelerator_type": _TEST_ACCELERATOR_TYPE, + "accelerator_count": _TEST_ACCELERATOR_COUNT, + }, + "disk_spec": { + "boot_disk_type": _TEST_BOOT_DISK_TYPE_DEFAULT, + "boot_disk_size_gb": _TEST_BOOT_DISK_SIZE_GB_DEFAULT, + }, + "python_package_spec": { + "executor_image_uri": _TEST_TRAINING_CONTAINER_IMAGE, + "python_module": _TEST_PYTHON_MODULE_NAME, + "package_uris": [_TEST_OUTPUT_PYTHON_PACKAGE_PATH], + "args": true_args, + "env": true_env, + }, + } + + true_fraction_split = gca_training_pipeline.FractionSplit( + training_fraction=_TEST_TRAINING_FRACTION_SPLIT, + validation_fraction=_TEST_VALIDATION_FRACTION_SPLIT, + test_fraction=_TEST_TEST_FRACTION_SPLIT, + ) + + env = [ + gca_env_var.EnvVar(name=str(key), value=str(value)) + for key, value in _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES.items() + ] + + ports = [ + gca_model.Port(container_port=port) + for port in _TEST_MODEL_SERVING_CONTAINER_PORTS + ] + + true_container_spec = gca_model.ModelContainerSpec( + image_uri=_TEST_SERVING_CONTAINER_IMAGE, + predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, + health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, + command=_TEST_MODEL_SERVING_CONTAINER_COMMAND, + args=_TEST_MODEL_SERVING_CONTAINER_ARGS, + env=env, + ports=ports, + ) + + true_managed_model = gca_model.Model( + display_name=_TEST_MODEL_DISPLAY_NAME, + labels=_TEST_MODEL_LABELS, + description=_TEST_MODEL_DESCRIPTION, + container_spec=true_container_spec, + predict_schemata=gca_model.PredictSchemata( + instance_schema_uri=_TEST_MODEL_INSTANCE_SCHEMA_URI, + parameters_schema_uri=_TEST_MODEL_PARAMETERS_SCHEMA_URI, + prediction_schema_uri=_TEST_MODEL_PREDICTION_SCHEMA_URI, + ), + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + fraction_split=true_fraction_split, + dataset_id=mock_tabular_dataset.name, + gcs_destination=gca_io.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + labels=_TEST_LABELS, + training_task_definition=schema.training_job.definition.custom_task, + training_task_inputs=json_format.ParseDict( + { + "worker_pool_specs": [true_worker_pool_spec], + "base_output_directory": { + "output_uri_prefix": _TEST_BASE_OUTPUT_DIR + }, + "service_account": _TEST_SERVICE_ACCOUNT, + "network": _TEST_NETWORK, + }, + struct_pb2.Value(), + ), + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + encryption_spec=_TEST_DEFAULT_ENCRYPTION_SPEC, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + timeout=None, + ) + @pytest.mark.parametrize("sync", [True, False]) def test_run_call_pipeline_service_create_with_tabular_dataset_without_model_display_name_nor_model_labels( self,