Skip to content

Commit

Permalink
chore: Add enable_probabilistic_inference flag to Vertex Forecast training jobs.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 569556089
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Sep 29, 2023
1 parent b9c9048 commit 1aab6fd
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 16 deletions.
44 changes: 38 additions & 6 deletions google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,7 @@ def run(
holiday_regions: Optional[List[str]] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
enable_probabilistic_inference: bool = False,
) -> models.Model:
"""Runs the training job and returns a model.
Expand Down Expand Up @@ -2080,6 +2081,15 @@ def run(
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
enable_probabilistic_inference (bool):
If probabilistic inference is enabled, the model will fit a
distribution that captures the uncertainty of a prediction. At
inference time, the predictive distribution is used to make a
point prediction that minimizes the optimization objective. For
example, the mean of a predictive distribution is the point
prediction that minimizes RMSE loss. If quantiles are specified,
then the quantiles of the distribution are also returned. The
optimization objective cannot be minimize-quantile-loss.
Returns:
model: The trained Vertex AI Model resource or None if training did not
produce a Vertex AI Model.
Expand Down Expand Up @@ -2148,6 +2158,7 @@ def run(
holiday_regions=holiday_regions,
sync=sync,
create_request_timeout=create_request_timeout,
enable_probabilistic_inference=enable_probabilistic_inference,
)

@base.optional_sync()
Expand Down Expand Up @@ -2193,6 +2204,7 @@ def _run(
holiday_regions: Optional[List[str]] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
enable_probabilistic_inference: bool = False,
) -> models.Model:
"""Runs the training job and returns a model.
Expand Down Expand Up @@ -2321,11 +2333,12 @@ def _run(
[export_evaluated_data_items_bigquery_destination_uri] is specified.
quantiles (List[float]):
Quantiles to use for the `minimize-quantile-loss`
[AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
this case.
[AutoMLForecastingTrainingJob.optimization_objective]. This
argument is required in this case. Quantiles may also optionally
be used if probabilistic inference is enabled.
Accepts up to 5 quantiles in the form of a double from 0 to 1, exclusive.
Each quantile must be unique.
Accepts up to 5 quantiles in the form of a double from 0 to 1,
exclusive. Each quantile must be unique.
validation_options (str):
Validation options for the data validation component. The available options are:
"fail-pipeline" - (default), will validate against the validation and fail the pipeline
Expand Down Expand Up @@ -2438,6 +2451,15 @@ def _run(
be immediately returned and synced when the Future has completed.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
enable_probabilistic_inference (bool):
If probabilistic inference is enabled, the model will fit a
distribution that captures the uncertainty of a prediction. At
inference time, the predictive distribution is used to make a
point prediction that minimizes the optimization objective. For
example, the mean of a predictive distribution is the point
prediction that minimizes RMSE loss. If quantiles are specified,
then the quantiles of the distribution are also returned. The
optimization objective cannot be minimize-quantile-loss.
Returns:
model: The trained Vertex AI Model resource or None if training did not
produce a Vertex AI Model.
Expand Down Expand Up @@ -2466,8 +2488,18 @@ def _run(
max_count=window_max_count,
)

# TODO(b/244643824): Replace additional experiments with a new job arg.
enable_probabilistic_inference = self._convert_enable_probabilistic_inference()
# Probabilistic inference flag should be removed from additional
# experiments in all cases since it is only an additional experiment in
# the SDK. If both are set, always prefer job arg for setting the field.
# TODO(b/244643824): Deprecate probabilistic inference in additional
# experiment and only use job arg.
additional_experiment_probabilistic_inference = (
self._convert_enable_probabilistic_inference()
)
if not enable_probabilistic_inference:
enable_probabilistic_inference = (
additional_experiment_probabilistic_inference
)

training_task_inputs_dict = {
# required inputs
Expand Down
101 changes: 91 additions & 10 deletions tests/unit/aiplatform/test_automl_forecasting_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
_TEST_WINDOW_STRIDE_LENGTH = 1
_TEST_WINDOW_MAX_COUNT = None
_TEST_TRAINING_HOLIDAY_REGIONS = ["GLOBAL"]
_TEST_ENABLE_PROBABILISTIC_INFERENCE = True
_TEST_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = [
"exp1",
"exp2",
Expand Down Expand Up @@ -148,15 +149,13 @@
struct_pb2.Value(),
)

_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE = (
json_format.ParseDict(
{
**_TEST_TRAINING_TASK_INPUTS_DICT,
"additionalExperiments": _TEST_ADDITIONAL_EXPERIMENTS,
"enableProbabilisticInference": True,
},
struct_pb2.Value(),
)
_TEST_TRAINING_TASK_INPUTS_WITH_PROBABILISTIC_INFERENCE = json_format.ParseDict(
{
**_TEST_TRAINING_TASK_INPUTS_DICT,
"additionalExperiments": _TEST_ADDITIONAL_EXPERIMENTS,
"enableProbabilisticInference": True,
},
struct_pb2.Value(),
)

_TEST_TRAINING_TASK_INPUTS = json_format.ParseDict(
Expand Down Expand Up @@ -1284,7 +1283,89 @@ def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference
true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=training_job._training_task_definition,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS_PROBABILISTIC_INFERENCE,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_PROBABILISTIC_INFERENCE,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
)

mock_pipeline_service_create.assert_called_once_with(
parent=initializer.global_config.common_location_path(),
training_pipeline=true_training_pipeline,
timeout=None,
)

@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.usefixtures("mock_pipeline_service_get")
@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.parametrize("training_job", _FORECASTING_JOB_MODEL_TYPES)
def test_run_call_pipeline_if_set_enable_probabilistic_inference(
self,
mock_pipeline_service_create,
mock_dataset_time_series,
mock_model_service_get,
sync,
training_job,
):
aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)

job = training_job(
display_name=_TEST_DISPLAY_NAME,
optimization_objective=_TEST_TRAINING_OPTIMIZATION_OBJECTIVE_NAME,
column_transformations=_TEST_TRAINING_COLUMN_TRANSFORMATIONS,
)

job._add_additional_experiments(_TEST_ADDITIONAL_EXPERIMENTS)

model_from_job = job.run(
dataset=mock_dataset_time_series,
target_column=_TEST_TRAINING_TARGET_COLUMN,
time_column=_TEST_TRAINING_TIME_COLUMN,
time_series_identifier_column=_TEST_TRAINING_TIME_SERIES_IDENTIFIER_COLUMN,
unavailable_at_forecast_columns=_TEST_TRAINING_UNAVAILABLE_AT_FORECAST_COLUMNS,
available_at_forecast_columns=_TEST_TRAINING_AVAILABLE_AT_FORECAST_COLUMNS,
forecast_horizon=_TEST_TRAINING_FORECAST_HORIZON,
data_granularity_unit=_TEST_TRAINING_DATA_GRANULARITY_UNIT,
data_granularity_count=_TEST_TRAINING_DATA_GRANULARITY_COUNT,
weight_column=_TEST_TRAINING_WEIGHT_COLUMN,
time_series_attribute_columns=_TEST_TRAINING_TIME_SERIES_ATTRIBUTE_COLUMNS,
context_window=_TEST_TRAINING_CONTEXT_WINDOW,
budget_milli_node_hours=_TEST_TRAINING_BUDGET_MILLI_NODE_HOURS,
export_evaluated_data_items=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS,
export_evaluated_data_items_bigquery_destination_uri=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_BIGQUERY_DESTINATION_URI,
export_evaluated_data_items_override_destination=_TEST_TRAINING_EXPORT_EVALUATED_DATA_ITEMS_OVERRIDE_DESTINATION,
quantiles=_TEST_TRAINING_QUANTILES,
validation_options=_TEST_TRAINING_VALIDATION_OPTIONS,
hierarchy_group_columns=_TEST_HIERARCHY_GROUP_COLUMNS,
hierarchy_group_total_weight=_TEST_HIERARCHY_GROUP_TOTAL_WEIGHT,
hierarchy_temporal_total_weight=_TEST_HIERARCHY_TEMPORAL_TOTAL_WEIGHT,
hierarchy_group_temporal_total_weight=_TEST_HIERARCHY_GROUP_TEMPORAL_TOTAL_WEIGHT,
window_column=_TEST_WINDOW_COLUMN,
window_stride_length=_TEST_WINDOW_STRIDE_LENGTH,
window_max_count=_TEST_WINDOW_MAX_COUNT,
sync=sync,
create_request_timeout=None,
holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS,
enable_probabilistic_inference=_TEST_ENABLE_PROBABILISTIC_INFERENCE,
)

if not sync:
model_from_job.wait()

# Test that it defaults to the job display name
true_managed_model = gca_model.Model(
display_name=_TEST_DISPLAY_NAME,
version_aliases=["default"],
)

true_input_data_config = gca_training_pipeline.InputDataConfig(
dataset_id=mock_dataset_time_series.name,
)

true_training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=_TEST_DISPLAY_NAME,
training_task_definition=training_job._training_task_definition,
training_task_inputs=_TEST_TRAINING_TASK_INPUTS_WITH_PROBABILISTIC_INFERENCE,
model_to_upload=true_managed_model,
input_data_config=true_input_data_config,
)
Expand Down

0 comments on commit 1aab6fd

Please sign in to comment.