Skip to content

Commit eef66fc

Browse files
committed
add seq2seq e2e tests
1 parent 3f84a96 commit eef66fc

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

tests/system/aiplatform/test_e2e_forecasting.py

+40-13
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ class TestEndToEndForecasting(e2e_base.TestEndToEnd):
3737

3838
def test_end_to_end_forecasting(self, shared_state):
3939
"""Builds a dataset, trains models, and gets batch predictions."""
40-
ds = None
41-
automl_job = None
42-
automl_model = None
43-
automl_batch_prediction_job = None
40+
resources = []
4441

4542
aiplatform.init(
4643
project=e2e_base._PROJECT,
@@ -69,12 +66,17 @@ def test_end_to_end_forecasting(self, shared_state):
6966
}
7067

7168
# Define both training jobs
72-
# TODO(humichael): Add seq2seq job.
7369
automl_job = aiplatform.AutoMLForecastingTrainingJob(
7470
display_name=self._make_display_name("train-housing-automl"),
7571
optimization_objective="minimize-rmse",
7672
column_specs=column_specs,
7773
)
74+
seq2seq_job = aiplatform.SequenceToSequencePlusForecastingTrainingJob(
75+
display_name=self._make_display_name("train-housing-seq2seq"),
76+
optimization_objective="minimize-rmse",
77+
column_specs=column_specs,
78+
)
79+
resources.extend([automl_job, seq2seq_job])
7880

7981
# Kick off both training jobs, AutoML job will take approx one hour
8082
# to run.
@@ -94,6 +96,23 @@ def test_end_to_end_forecasting(self, shared_state):
9496
model_display_name=self._make_display_name("automl-liquor-model"),
9597
sync=False,
9698
)
99+
seq2seq_model = seq2seq_job.run(
100+
dataset=ds,
101+
target_column=target_column,
102+
time_column=time_column,
103+
time_series_identifier_column=time_series_identifier_column,
104+
available_at_forecast_columns=[time_column],
105+
unavailable_at_forecast_columns=[target_column],
106+
time_series_attribute_columns=["city", "zip_code", "county"],
107+
forecast_horizon=30,
108+
context_window=30,
109+
data_granularity_unit="day",
110+
data_granularity_count=1,
111+
budget_milli_node_hours=1000,
112+
model_display_name=self._make_display_name("seq2seq-liquor-model"),
113+
sync=False,
114+
)
115+
resources.extend([automl_model, seq2seq_model])
97116

98117
automl_batch_prediction_job = automl_model.batch_predict(
99118
job_display_name=self._make_display_name("automl-liquor-model"),
@@ -105,8 +124,22 @@ def test_end_to_end_forecasting(self, shared_state):
105124
),
106125
sync=False,
107126
)
127+
seq2seq_batch_prediction_job = seq2seq_model.batch_predict(
128+
job_display_name=self._make_display_name("seq2seq-liquor-model"),
129+
instances_format="bigquery",
130+
machine_type="n1-standard-4",
131+
bigquery_source=_PREDICTION_DATASET_BQ_PATH,
132+
gcs_destination_prefix=(
133+
f'gs://{shared_state["staging_bucket_name"]}/bp_results/'
134+
),
135+
sync=False,
136+
)
137+
resources.extend(
138+
[automl_batch_prediction_job, seq2seq_batch_prediction_job]
139+
)
108140

109141
automl_batch_prediction_job.wait()
142+
seq2seq_batch_prediction_job.wait()
110143

111144
assert (
112145
automl_job.state
@@ -117,11 +150,5 @@ def test_end_to_end_forecasting(self, shared_state):
117150
== job_state.JobState.JOB_STATE_SUCCEEDED
118151
)
119152
finally:
120-
if ds is not None:
121-
ds.delete()
122-
if automl_job is not None:
123-
automl_job.delete()
124-
if automl_model is not None:
125-
automl_model.delete()
126-
if automl_batch_prediction_job is not None:
127-
automl_batch_prediction_job.delete()
153+
for resource in resources:
154+
resource.delete()

0 commit comments

Comments
 (0)