-
Notifications
You must be signed in to change notification settings - Fork 350
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add Service Account support to BatchPredictionJob
COPYBARA_INTEGRATE_REVIEW=#1872 from cymarechal-devoteam:feature/batch-prediction/service-account 4f015f3 PiperOrigin-RevId: 501301075
- Loading branch information
1 parent
369a0cc
commit deba06b
Showing
6 changed files
with
52 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -359,10 +359,11 @@ To create a batch prediction job: | |
batch_prediction_job = model.batch_predict( | ||
job_display_name='my-batch-prediction-job', | ||
instances_format='csv' | ||
instances_format='csv', | ||
machine_type='n1-standard-4', | ||
gcs_source=['gs://path/to/my/file.csv'] | ||
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/' | ||
gcs_source=['gs://path/to/my/file.csv'], | ||
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/', | ||
service_account='[email protected]' | ||
) | ||
You can also create a batch prediction job asynchronously by including the `sync=False` argument: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -284,10 +284,11 @@ To create a batch prediction job: | |
batch_prediction_job = model.batch_predict( | ||
job_display_name='my-batch-prediction-job', | ||
instances_format='csv' | ||
instances_format='csv', | ||
machine_type='n1-standard-4', | ||
gcs_source=['gs://path/to/my/file.csv'] | ||
gcs_destination_prefix='gs://path/to/by/batch_prediction/results/' | ||
gcs_source=['gs://path/to/my/file.csv'], | ||
gcs_destination_prefix='gs://path/to/my/batch_prediction/results/', | ||
service_account='[email protected]' | ||
) | ||
You can also create a batch prediction job asynchronously by including the `sync=False` argument: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,6 +76,8 @@ | |
_TEST_BQ_JOB_ID = "123459876" | ||
_TEST_BQ_MAX_RESULTS = 100 | ||
_TEST_GCS_BUCKET_NAME = "my-bucket" | ||
_TEST_SERVICE_ACCOUNT = "[email protected]" | ||
|
||
|
||
_TEST_BQ_PATH = f"bq://{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}" | ||
_TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}" | ||
|
@@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest( | |
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, | ||
sync=sync, | ||
create_request_timeout=None, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest( | |
), | ||
predictions_format="jsonl", | ||
), | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
create_batch_prediction_job_mock.assert_called_once_with( | ||
|
@@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout( | |
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, | ||
sync=sync, | ||
create_request_timeout=180.0, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout( | |
), | ||
predictions_format="jsonl", | ||
), | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
create_batch_prediction_job_mock.assert_called_once_with( | ||
|
@@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set( | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, | ||
sync=sync, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set( | |
), | ||
predictions_format="jsonl", | ||
), | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
create_batch_prediction_job_mock.assert_called_once_with( | ||
|
@@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock): | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, | ||
sync=False, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest( | |
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
sync=sync, | ||
create_request_timeout=None, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest( | |
), | ||
predictions_format="bigquery", | ||
), | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
create_batch_prediction_job_mock.assert_called_once_with( | ||
|
@@ -946,6 +957,7 @@ def test_batch_predict_with_all_args( | |
sync=sync, | ||
create_request_timeout=None, | ||
batch_size=_TEST_BATCH_SIZE, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -986,6 +998,7 @@ def test_batch_predict_with_all_args( | |
parameters=_TEST_EXPLANATION_PARAMETERS, | ||
), | ||
labels=_TEST_LABEL, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
create_batch_prediction_job_with_explanations_mock.assert_called_once_with( | ||
|
@@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring( | |
model_monitoring_objective_config=mm_obj_cfg, | ||
model_monitoring_alert_config=mm_alert_cfg, | ||
analysis_instance_schema_uri="", | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
batch_prediction_job.wait_for_resource_creation() | ||
|
@@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring( | |
generate_explanation=True, | ||
model_monitoring_config=_TEST_MODEL_MONITORING_CFG, | ||
labels=_TEST_LABEL, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
create_batch_prediction_job_v1beta1_mock.assert_called_once_with( | ||
parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", | ||
|
@@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self): | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
sync=False, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
with pytest.raises(RuntimeError) as e: | ||
|
@@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock): | |
model_name=_TEST_MODEL_NAME, | ||
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, | ||
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
assert e.match(regexp=r"source") | ||
|
@@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
assert e.match(regexp=r"source") | ||
|
@@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self): | |
model_name=_TEST_MODEL_NAME, | ||
job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, | ||
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
assert e.match(regexp=r"destination") | ||
|
@@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self): | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
instances_format="wrong", | ||
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
assert e.match(regexp=r"accepted instances format") | ||
|
@@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self): | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
predictions_format="wrong", | ||
bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
|
||
assert e.match(regexp=r"accepted prediction format") | ||
|
@@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model( | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, | ||
sync=True, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
assert ( | ||
create_batch_prediction_job_mock.call_args_list[0][1][ | ||
|
@@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model( | |
gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, | ||
gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, | ||
sync=True, | ||
service_account=_TEST_SERVICE_ACCOUNT, | ||
) | ||
assert ( | ||
create_batch_prediction_job_mock.call_args_list[0][1][ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters