From 0036ab07004e0c9ae7806c4c2c25f22d5af4a978 Mon Sep 17 00:00:00 2001 From: Michael Hu Date: Fri, 27 May 2022 17:44:20 -0400 Subject: [PATCH] feat: add holiday regions for vertex forecasting (#1253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add holiday regions for Vertex Forecasting. --- Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) Fixes b/232964228 🦕 --- google/cloud/aiplatform/training_jobs.py | 32 ++++++++++++++++++- .../test_automl_forecasting_training_jobs.py | 13 ++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index ce5ddf0f94..a6244e08ca 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -4044,6 +4044,7 @@ def run( window_column: Optional[str] = None, window_stride_length: Optional[int] = None, window_max_count: Optional[int] = None, + holiday_regions: Optional[List[str]] = None, sync: bool = True, create_request_timeout: Optional[float] = None, ) -> models.Model: @@ -4238,10 +4239,23 @@ def run( Optional. Number of rows that should be used to generate input examples. If the total row count is larger than this number, the input data will be randomly sampled to hit the count. + holiday_regions (List[str]): + Optional. The geographical regions to use when creating holiday + features. This option is only allowed when data_granularity_unit + is ``day``. Acceptable values can come from any of the following + levels: + Top level: GLOBAL + Second level: continental regions + NA: North America + JAPAC: Japan and Asia Pacific + EMEA: Europe, the Middle East and Africa + LAC: Latin America and the Caribbean + Third level: countries from ISO 3166-1 Country codes. sync (bool): - Whether to execute this method synchronously. If False, this method + Optional. Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + Returns: model: The trained Vertex AI Model resource or None if training did not produce a Vertex AI Model. @@ -4299,6 +4313,7 @@ def run( window_column=window_column, window_stride_length=window_stride_length, window_max_count=window_max_count, + holiday_regions=holiday_regions, sync=sync, create_request_timeout=create_request_timeout, ) @@ -4338,6 +4353,7 @@ def _run( window_column: Optional[str] = None, window_stride_length: Optional[int] = None, window_max_count: Optional[int] = None, + holiday_regions: Optional[List[str]] = None, sync: bool = True, create_request_timeout: Optional[float] = None, ) -> models.Model: @@ -4536,12 +4552,25 @@ def _run( Optional. Number of rows that should be used to generate input examples. If the total row count is larger than this number, the input data will be randomly sampled to hit the count. + holiday_regions (List[str]): + Optional. The geographical regions to use when creating holiday + features. This option is only allowed when data_granularity_unit + is ``day``. Acceptable values can come from any of the following + levels: + Top level: GLOBAL + Second level: continental regions + NA: North America + JAPAC: Japan and Asia Pacific + EMEA: Europe, the Middle East and Africa + LAC: Latin America and the Caribbean + Third level: countries from ISO 3166-1 Country codes. sync (bool): Whether to execute this method synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. create_request_timeout (float): Optional. The timeout for the create request in seconds. + Returns: model: The trained Vertex AI Model resource or None if training did not produce a Vertex AI Model. @@ -4592,6 +4621,7 @@ def _run( "quantiles": quantiles, "validationOptions": validation_options, "optimizationObjective": self._optimization_objective, + "holidayRegions": holiday_regions, } # TODO(TheMichaelHu): Remove the ifs once the API supports these inputs. diff --git a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py index ecc3f544a0..21ca78da2e 100644 --- a/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py +++ b/tests/unit/aiplatform/test_automl_forecasting_training_jobs.py @@ -88,6 +88,7 @@ _TEST_WINDOW_COLUMN = None _TEST_WINDOW_STRIDE_LENGTH = 1 _TEST_WINDOW_MAX_COUNT = None +_TEST_TRAINING_HOLIDAY_REGIONS = ["GLOBAL"] _TEST_TRAINING_TASK_INPUTS_DICT = { # required inputs "targetColumn": _TEST_TRAINING_TARGET_COLUMN, @@ -122,6 +123,7 @@ "windowConfig": { "strideLength": _TEST_WINDOW_STRIDE_LENGTH, }, + "holidayRegions": _TEST_TRAINING_HOLIDAY_REGIONS, } _TEST_TRAINING_TASK_INPUTS_WITH_ADDITIONAL_EXPERIMENTS = json_format.ParseDict( @@ -322,6 +324,7 @@ def test_run_call_pipeline_service_create( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -417,6 +420,7 @@ def test_run_call_pipeline_service_create_with_timeout( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=180.0, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -494,6 +498,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -571,6 +576,7 @@ def test_run_call_pipeline_if_set_additional_experiments( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -644,6 +650,7 @@ def test_run_called_twice_raises( window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) with pytest.raises(RuntimeError): @@ -675,6 +682,7 @@ def test_run_called_twice_raises( window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) @pytest.mark.parametrize("sync", [True, False]) @@ -722,6 +730,7 @@ def test_run_raises_if_pipeline_fails( window_stride_length=_TEST_WINDOW_STRIDE_LENGTH, window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -805,6 +814,7 @@ def test_splits_fraction( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -900,6 +910,7 @@ def test_splits_timestamp( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -993,6 +1004,7 @@ def test_splits_predefined( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: @@ -1081,6 +1093,7 @@ def test_splits_default( window_max_count=_TEST_WINDOW_MAX_COUNT, sync=sync, create_request_timeout=None, + holiday_regions=_TEST_TRAINING_HOLIDAY_REGIONS, ) if not sync: