@@ -37,10 +37,7 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
37
37
38
38
def test_end_to_end_forecasting (self , shared_state ):
39
39
"""Builds a dataset, trains models, and gets batch predictions."""
40
- ds = None
41
- automl_job = None
42
- automl_model = None
43
- automl_batch_prediction_job = None
40
+ resources = []
44
41
45
42
aiplatform .init (
46
43
project = e2e_base ._PROJECT ,
@@ -69,12 +66,17 @@ def test_end_to_end_forecasting(self, shared_state):
69
66
}
70
67
71
68
# Define both training jobs
72
- # TODO(humichael): Add seq2seq job.
73
69
automl_job = aiplatform .AutoMLForecastingTrainingJob (
74
70
display_name = self ._make_display_name ("train-housing-automl" ),
75
71
optimization_objective = "minimize-rmse" ,
76
72
column_specs = column_specs ,
77
73
)
74
+ seq2seq_job = aiplatform .SequenceToSequencePlusForecastingTrainingJob (
75
+ display_name = self ._make_display_name ("train-housing-seq2seq" ),
76
+ optimization_objective = "minimize-rmse" ,
77
+ column_specs = column_specs ,
78
+ )
79
+ resources .extend ([automl_job , seq2seq_job ])
78
80
79
81
# Kick off both training jobs, AutoML job will take approx one hour
80
82
# to run.
@@ -94,6 +96,23 @@ def test_end_to_end_forecasting(self, shared_state):
94
96
model_display_name = self ._make_display_name ("automl-liquor-model" ),
95
97
sync = False ,
96
98
)
99
+ seq2seq_model = seq2seq_job .run (
100
+ dataset = ds ,
101
+ target_column = target_column ,
102
+ time_column = time_column ,
103
+ time_series_identifier_column = time_series_identifier_column ,
104
+ available_at_forecast_columns = [time_column ],
105
+ unavailable_at_forecast_columns = [target_column ],
106
+ time_series_attribute_columns = ["city" , "zip_code" , "county" ],
107
+ forecast_horizon = 30 ,
108
+ context_window = 30 ,
109
+ data_granularity_unit = "day" ,
110
+ data_granularity_count = 1 ,
111
+ budget_milli_node_hours = 1000 ,
112
+ model_display_name = self ._make_display_name ("seq2seq-liquor-model" ),
113
+ sync = False ,
114
+ )
115
+ resources .extend ([automl_model , seq2seq_model ])
97
116
98
117
automl_batch_prediction_job = automl_model .batch_predict (
99
118
job_display_name = self ._make_display_name ("automl-liquor-model" ),
@@ -105,8 +124,22 @@ def test_end_to_end_forecasting(self, shared_state):
105
124
),
106
125
sync = False ,
107
126
)
127
+ seq2seq_batch_prediction_job = seq2seq_model .batch_predict (
128
+ job_display_name = self ._make_display_name ("seq2seq-liquor-model" ),
129
+ instances_format = "bigquery" ,
130
+ machine_type = "n1-standard-4" ,
131
+ bigquery_source = _PREDICTION_DATASET_BQ_PATH ,
132
+ gcs_destination_prefix = (
133
+ f'gs://{ shared_state ["staging_bucket_name" ]} /bp_results/'
134
+ ),
135
+ sync = False ,
136
+ )
137
+ resources .extend (
138
+ [automl_batch_prediction_job , seq2seq_batch_prediction_job ]
139
+ )
108
140
109
141
automl_batch_prediction_job .wait ()
142
+ seq2seq_batch_prediction_job .wait ()
110
143
111
144
assert (
112
145
automl_job .state
@@ -117,11 +150,5 @@ def test_end_to_end_forecasting(self, shared_state):
117
150
== job_state .JobState .JOB_STATE_SUCCEEDED
118
151
)
119
152
finally :
120
- if ds is not None :
121
- ds .delete ()
122
- if automl_job is not None :
123
- automl_job .delete ()
124
- if automl_model is not None :
125
- automl_model .delete ()
126
- if automl_batch_prediction_job is not None :
127
- automl_batch_prediction_job .delete ()
153
+ for resource in resources :
154
+ resource .delete ()
0 commit comments