diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 268ab8fdf4..35a10529b8 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -1842,6 +1842,7 @@ def run( holiday_regions: Optional[List[str]] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + enable_probabilistic_inference: bool = False, ) -> models.Model: """Runs the training job and returns a model. @@ -2080,6 +2081,15 @@ def run( Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + enable_probabilistic_inference (bool): + If probabilistic inference is enabled, the model will fit a + distribution that captures the uncertainty of a prediction. At + inference time, the predictive distribution is used to make a + point prediction that minimizes the optimization objective. For + example, the mean of a predictive distribution is the point + prediction that minimizes RMSE loss. If quantiles are specified, + then the quantiles of the distribution are also returned. The + optimization objective cannot be minimize-quantile-loss. Returns: model: The trained Vertex AI Model resource or None if training did not produce a Vertex AI Model. @@ -2148,6 +2158,7 @@ def run( holiday_regions=holiday_regions, sync=sync, create_request_timeout=create_request_timeout, + enable_probabilistic_inference=enable_probabilistic_inference, ) @base.optional_sync() @@ -2193,6 +2204,7 @@ def _run( holiday_regions: Optional[List[str]] = None, sync: bool = True, create_request_timeout: Optional[float] = None, + enable_probabilistic_inference: bool = False, ) -> models.Model: """Runs the training job and returns a model. @@ -2321,11 +2333,12 @@ def _run( [export_evaluated_data_items_bigquery_destination_uri] is specified. quantiles (List[float]): Quantiles to use for the `minimize-quantile-loss` - [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in - this case. + [AutoMLForecastingTrainingJob.optimization_objective]. This + argument is required in this case. Quantiles may also optionally + be used if probabilistic inference is enabled. - Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive. - Each quantile must be unique. + Accepts up to 5 quantiles in the form of a double from 0 to 1, + exclusive. Each quantile must be unique. validation_options (str): Validation options for the data validation component. The available options are: "fail-pipeline" - (default), will validate against the validation and fail the pipeline @@ -2438,6 +2451,15 @@ def _run( be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + enable_probabilistic_inference (bool): + If probabilistic inference is enabled, the model will fit a + distribution that captures the uncertainty of a prediction. At + inference time, the predictive distribution is used to make a + point prediction that minimizes the optimization objective. For + example, the mean of a predictive distribution is the point + prediction that minimizes RMSE loss. If quantiles are specified, + then the quantiles of the distribution are also returned. The + optimization objective cannot be minimize-quantile-loss. Returns: model: The trained Vertex AI Model resource or None if training did not produce a Vertex AI Model. @@ -2466,8 +2488,18 @@ def _run( max_count=window_max_count, ) - # TODO(b/244643824): Replace additional experiments with a new job arg. - enable_probabilistic_inference = self._convert_enable_probabilistic_inference() + # Probabilistic inference flag should be removed from additional + # experiments in all cases since it is only an additional experiment in + # the SDK. If both are set, always prefer job arg for setting the field. + # TODO(b/244643824): Deprecate probabilistic inference in additional + # experiment and only use job arg. + additional_experiment_probabilistic_inference = ( + self._convert_enable_probabilistic_inference() + ) + if not enable_probabilistic_inference: + enable_probabilistic_inference = ( + additional_experiment_probabilistic_inference + ) training_task_inputs_dict = { # required inputs diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index 19e3dfde41..c08c70381a 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -98,6 +98,7 @@ _TEST_WINDOW_STRIDE_LENGTH = 1 _TEST_WINDOW_MAX_COUNT = None _TEST_TRAINING_HOLIDAY_REGIONS = ["GLOBAL"] +_TEST_ENABLE_PROBABILISTIC_INFERENCE = True _TEST_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = [ "exp1", "exp2", @@ -148,15 +149,13 @@ struct_pb2.Value(), ) -_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = ( - json_format.ParseDict( - { - **_TEST_TRAINING_TASK_INPUTS_DICT, - "additionalExperiments": _TEST_ADDITIONAL_EXPERIMENTS, - "enableProbabilisticInference": True, - }, - struct_pb2.Value(), - ) +_TEST_TRAINING_TASK_INPUTS_WITH_PROBABILISTIC_INFERENCE = json_format.ParseDict( + { + **_TEST_TRAINING_TASK_INPUTS_DICT, + "additionalExperiments": _TEST_ADDITIONAL_EXPERIMENTS, + "enableProbabilisticInference": True, + }, + struct_pb2.Value(), ) _TEST_TRAINING_TASK_INPUTS = json_format.ParseDict( @@ -1284,7 +1283,89 @@ def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference true_training_pipeline = gca_training_pipeline.TrainingPipeline( display_name=_TEST_DISPLAY_NAME, training_task_definition=training_job._training_task_definition, - training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_PROBABILISTIC_INFERENCE, + model_to_upload=true_managed_model, + input_data_config=true_input_data_config, + ) + + mock_pipeline_service_create.assert_called_once_with( + parent=initializer.global_config.common_location_path(), + training_pipeline=true_training_pipeline, + timeout=None, + ) + + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures("mock_pipeline_service_get") + @pytest.mark.parametrize("sync", [True, False]) + @pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES) + def test_run_call_pipeline_if_set_enable_probabilistic_inference( + self, + mock_pipeline_service_create, + mock_dataset_time_series, + mock_model_service_get, + sync, + training_job, + ): + aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME) + + job = training_job( + display_name=_TEST_DISPLAY_NAME, + optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME, + column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS, + ) + + job._add_additional_experiments(_TEST_ADDITIONAL_EXPERIMENTS) + + model_from_job = job.run( + dataset=mock_dataset_time_series, + target_column=_TEST_TRAINING_TARGET_COLUMN, + time_column=_TEST_TRAINING_TIME_COLUMN, + time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN, + unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS, + available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS, + forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON, + data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT, + data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT, + weight_column=_TEST_TRAINING_WEIGHT_COLUMN, + time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS, + context_window=_TEST_TRAINING_CONTEXT_WINDOW, + budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS, + export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS, + export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI, + export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION, + quantiles=_TEST_TRAINING_QUANTILES, + validation_options=_TEST_TRAINING_VALIDATION_OPTIONS, + hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS, + hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT, + hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT, + hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT, + window_column=_TEST_WINDOW_COLUMN, + window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, + window_max_count=_TEST_WINDOW_MAX_COUNT, + sync=sync, + create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, + enable_probabilistic_inference=_TEST_ENABLE_PROBABILISTIC_INFERENCE, + ) + + if not sync: + model_from_job.wait() + + # Test that if defaults to the job display name + true_managed_model = gca_model.Model( + display_name=_TEST_DISPLAY_NAME, + version_aliases=["default"], + ) + + true_input_data_config = gca_training_pipeline.InputDataConfig( + dataset_id=mock_dataset_time_series.name, + ) + + true_training_pipeline = gca_training_pipeline.TrainingPipeline( + display_name=_TEST_DISPLAY_NAME, + training_task_definition=training_job._training_task_definition, + training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_PROBABILISTIC_INFERENCE, model_to_upload=true_managed_model, input_data_config=true_input_data_config, )