Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add hierarchy and window configs to Vertex Forecasting training job #1255

Merged
merged 7 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 131 additions & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4037,6 +4037,13 @@ def run(
model_display_name: Optional[str] = None,
model_labels: Optional[Dict[str, str]] = None,
additional_experiments: Optional[List[str]] = None,
hierarchy_group_columns: Optional[List[str]] = None,
hierarchy_group_total_weight: Optional[float] = None,
hierarchy_temporal_total_weight: Optional[float] = None,
hierarchy_group_temporal_total_weight: Optional[float] = None,
window_column: Optional[str] = None,
window_stride_length: Optional[int] = None,
window_max_count: Optional[int] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> models.Model:
Expand Down Expand Up @@ -4157,7 +4164,7 @@ def run(
Applies only if [export_evaluated_data_items] is True and
[export_evaluated_data_items_bigquery_destination_uri] is specified.
quantiles (List[float]):
Quantiles to use for the `minimize-quantile-loss`
Quantiles to use for the ``minimize-quantile-loss``
TheMichaelHu marked this conversation as resolved.
Show resolved Hide resolved
[AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
this case.

Expand Down Expand Up @@ -4200,6 +4207,37 @@ def run(
Optional. Additional experiment flags for the time series forcasting training.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
hierarchy_group_columns (List[str]):
Optional. A list of time series attribute column names that
define the time series hierarchy. Only one level of hierarchy is
supported, ex. ``region`` for a hierarchy of stores or
``department`` for a hierarchy of products. If multiple columns
are specified, time series will be grouped by their combined
values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
to 5 columns are accepted. If no group columns are specified,
all time series are considered to be part of the same group.
hierarchy_group_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
time series in the same hierarchy group.
hierarchy_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
the horizon for a single time series.
hierarchy_group_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
both the horizon and time series in the same hierarchy group.
window_column (str):
Optional. Name of the column that should be used to filter input
rows. The column should contain either booleans or string
booleans; if the value of the row is True, generate a sliding
window from that row.
window_stride_length (int):
Optional. Step length used to generate input examples. Every
``window_stride_length`` rows will be used to generate a sliding
window.
window_max_count (int):
Optional. Number of rows that should be used to generate input
examples. If the total row count is larger than this number, the
input data will be randomly sampled to hit the count.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
Expand Down Expand Up @@ -4254,6 +4292,13 @@ def run(
validation_options=validation_options,
model_display_name=model_display_name,
model_labels=model_labels,
hierarchy_group_columns=hierarchy_group_columns,
hierarchy_group_total_weight=hierarchy_group_total_weight,
hierarchy_temporal_total_weight=hierarchy_temporal_total_weight,
hierarchy_group_temporal_total_weight=hierarchy_group_temporal_total_weight,
window_column=window_column,
window_stride_length=window_stride_length,
window_max_count=window_max_count,
sync=sync,
create_request_timeout=create_request_timeout,
)
Expand Down Expand Up @@ -4286,6 +4331,13 @@ def _run(
budget_milli_node_hours: int = 1000,
model_display_name: Optional[str] = None,
model_labels: Optional[Dict[str, str]] = None,
hierarchy_group_columns: Optional[List[str]] = None,
hierarchy_group_total_weight: Optional[float] = None,
hierarchy_temporal_total_weight: Optional[float] = None,
hierarchy_group_temporal_total_weight: Optional[float] = None,
window_column: Optional[str] = None,
window_stride_length: Optional[int] = None,
window_max_count: Optional[int] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
) -> models.Model:
Expand Down Expand Up @@ -4453,6 +4505,37 @@ def _run(
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
hierarchy_group_columns (List[str]):
Optional. A list of time series attribute column names that
define the time series hierarchy. Only one level of hierarchy is
supported, ex. ``region`` for a hierarchy of stores or
``department`` for a hierarchy of products. If multiple columns
are specified, time series will be grouped by their combined
values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
to 5 columns are accepted. If no group columns are specified,
all time series are considered to be part of the same group.
hierarchy_group_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
time series in the same hierarchy group.
hierarchy_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
the horizon for a single time series.
hierarchy_group_temporal_total_weight (float):
Optional. The weight of the loss for predictions aggregated over
both the horizon and time series in the same hierarchy group.
window_column (str):
Optional. Name of the column that should be used to filter input
rows. The column should contain either booleans or string
booleans; if the value of the row is True, generate a sliding
window from that row.
window_stride_length (int):
Optional. Step length used to generate input examples. Every
``window_stride_length`` rows will be used to generate a sliding
window.
window_max_count (int):
Optional. Number of rows that should be used to generate input
examples. If the total row count is larger than this number, the
input data will be randomly sampled to hit the count.
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
Expand Down Expand Up @@ -4482,6 +4565,12 @@ def _run(
% column_names
)

window_config = self._create_window_config(
column=window_column,
stride_length=window_stride_length,
max_count=window_max_count,
)

training_task_inputs_dict = {
# required inputs
"targetColumn": target_column,
Expand All @@ -4505,6 +4594,24 @@ def _run(
"optimizationObjective": self._optimization_objective,
}

# TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
if any(
[
hierarchy_group_columns,
hierarchy_group_total_weight,
hierarchy_temporal_total_weight,
hierarchy_group_temporal_total_weight,
]
):
training_task_inputs_dict["hierarchyConfig"] = {
"groupColumns": hierarchy_group_columns,
"groupTotalWeight": hierarchy_group_total_weight,
"temporalTotalWeight": hierarchy_temporal_total_weight,
"groupTemporalTotalWeight": hierarchy_group_temporal_total_weight,
}
if window_config:
training_task_inputs_dict["windowConfig"] = window_config

final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
if final_export_eval_bq_uri and not final_export_eval_bq_uri.startswith(
"bq://"
Expand Down Expand Up @@ -4582,6 +4689,29 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
"""
self._additional_experiments.extend(additional_experiments)

@staticmethod
def _create_window_config(
column: Optional[str] = None,
stride_length: Optional[int] = None,
max_count: Optional[int] = None,
) -> Optional[Dict[str, Union[int, str]]]:
"""Creates a window config from training job arguments."""
configs = {
"column": column,
"strideLength": stride_length,
"maxCount": max_count,
}
present_configs = {k: v for k, v in configs.items() if v is not None}
if not present_configs:
return None
if len(present_configs) > 1:
raise ValueError(
TheMichaelHu marked this conversation as resolved.
Show resolved Hide resolved
"More than one windowing strategy provided. Make sure only one "
"of window_column, window_stride_length, or window_max_count "
"is specified."
)
return present_configs


class AutoMLImageTrainingJob(_TrainingJob):
_supported_training_schemas = (
Expand Down
Loading