From deba06b938afa695b5fb2d8184647109913abd7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cyril=20Mar=C3=A9chal?= Date: Wed, 11 Jan 2023 09:21:18 -0800 Subject: [PATCH] feat: add Service Account support to BatchPredictionJob COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/1872 from cymarechal-devoteam:feature/batch-prediction/service-account 4f015f3f8a8c0dbdb76511732f45dba809aa1dec PiperOrigin-RevId: 501301075 --- README.rst | 7 ++++--- docs/README.rst | 7 ++++--- google/cloud/aiplatform/jobs.py | 7 +++++++ google/cloud/aiplatform/models.py | 5 +++++ tests/unit/aiplatform/test_jobs.py | 23 +++++++++++++++++++++++ tests/unit/aiplatform/test_models.py | 9 +++++++++ 6 files changed, 52 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 8a95b43fd0..7ebab90d30 100644 --- a/README.rst +++ b/README.rst @@ -359,10 +359,11 @@ To create a batch prediction job: batch_prediction_job = model.batch_predict( job_display_name='my-batch-prediction-job', - instances_format='csv' + instances_format='csv', machine_type='n1-standard-4', - gcs_source=['gs://path/to/my/file.csv'] - gcs_destination_prefix='gs://path/to/by/batch_prediction/results/' + gcs_source=['gs://path/to/my/file.csv'], + gcs_destination_prefix='gs://path/to/my/batch_prediction/results/', + service_account='my-sa@my-project.iam.gserviceaccount.com' ) You can also create a batch prediction job asynchronously by including the `sync=False` argument: diff --git a/docs/README.rst b/docs/README.rst index 1288053eb2..78821c3395 100644 --- a/docs/README.rst +++ b/docs/README.rst @@ -284,10 +284,11 @@ To create a batch prediction job: batch_prediction_job = model.batch_predict( job_display_name='my-batch-prediction-job', - instances_format='csv' + instances_format='csv', machine_type='n1-standard-4', - gcs_source=['gs://path/to/my/file.csv'] - gcs_destination_prefix='gs://path/to/by/batch_prediction/results/' + gcs_source=['gs://path/to/my/file.csv'], + gcs_destination_prefix='gs://path/to/my/batch_prediction/results/', + service_account='my-sa@my-project.iam.gserviceaccount.com' ) You can also create a batch prediction job asynchronously by including the `sync=False` argument: diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index e7f5076823..9ac35fdb17 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -403,6 +403,7 @@ def create( "aiplatform.model_monitoring.AlertConfig" ] = None, analysis_instance_schema_uri: Optional[str] = None, + service_account: Optional[str] = None, ) -> "BatchPredictionJob": """Create a batch prediction job. @@ -586,6 +587,9 @@ def create( and TFDV instance, this field can be used to override the schema. For models trained with Vertex AI, this field must be set as all the fields in predict instance formatted as string. + service_account (str): + Optional. Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. Returns: (jobs.BatchPredictionJob): Instantiated representation of the created batch prediction job. @@ -745,6 +749,9 @@ def create( ) gapic_batch_prediction_job.explanation_spec = explanation_spec + if service_account: + gapic_batch_prediction_job.service_account = service_account + empty_batch_prediction_job = cls._empty_constructor( project=project, location=location, diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 3078da8e6d..77fb533258 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -3511,6 +3511,7 @@ def batch_predict( sync: bool = True, create_request_timeout: Optional[float] = None, batch_size: Optional[int] = None, + service_account: Optional[str] = None, ) -> jobs.BatchPredictionJob: """Creates a batch prediction job using this Model and outputs prediction results to the provided destination prefix in the specified @@ -3673,6 +3674,9 @@ def batch_predict( but too high value will result in a whole batch not fitting in a machine's memory, and the whole operation will fail. The default value is 64. + service_account (str): + Optional. Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. Returns: job (jobs.BatchPredictionJob): @@ -3705,6 +3709,7 @@ def batch_predict( encryption_spec_key_name=encryption_spec_key_name, sync=sync, create_request_timeout=create_request_timeout, + service_account=service_account, ) @classmethod diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index a38bf53a8a..c99d33da5a 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -76,6 +76,8 @@ _TEST_BQ_JOB_ID = "123459876" _TEST_BQ_MAX_RESULTS = 100 _TEST_GCS_BUCKET_NAME = "my-bucket" +_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com" + _TEST_BQ_PATH = f"bq://{_TEST_BQ_PROJECT_ID}.{_TEST_BQ_DATASET_ID}" _TEST_GCS_BUCKET_PATH = f"gs://{_TEST_GCS_BUCKET_NAME}" @@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest( gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, create_request_timeout=None, + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest( ), predictions_format="jsonl", ), + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_mock.assert_called_once_with( @@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout( gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, create_request_timeout=180.0, + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout( ), predictions_format="jsonl", ), + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_mock.assert_called_once_with( @@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set( gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set( ), predictions_format="jsonl", ), + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_mock.assert_called_once_with( @@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock): gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=False, + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest( bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, sync=sync, create_request_timeout=None, + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest( ), predictions_format="bigquery", ), + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_mock.assert_called_once_with( @@ -946,6 +957,7 @@ def test_batch_predict_with_all_args( sync=sync, create_request_timeout=None, batch_size=_TEST_BATCH_SIZE, + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -986,6 +998,7 @@ def test_batch_predict_with_all_args( parameters=_TEST_EXPLANATION_PARAMETERS, ), labels=_TEST_LABEL, + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_with_explanations_mock.assert_called_once_with( @@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring( model_monitoring_objective_config=mm_obj_cfg, model_monitoring_alert_config=mm_alert_cfg, analysis_instance_schema_uri="", + service_account=_TEST_SERVICE_ACCOUNT, ) batch_prediction_job.wait_for_resource_creation() @@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring( generate_explanation=True, model_monitoring_config=_TEST_MODEL_MONITORING_CFG, labels=_TEST_LABEL, + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_v1beta1_mock.assert_called_once_with( parent=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", @@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self): gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, sync=False, + service_account=_TEST_SERVICE_ACCOUNT, ) with pytest.raises(RuntimeError) as e: @@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock): model_name=_TEST_MODEL_NAME, job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + service_account=_TEST_SERVICE_ACCOUNT, ) assert e.match(regexp=r"source") @@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock): gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, bigquery_source=_TEST_BATCH_PREDICTION_BQ_PREFIX, bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + service_account=_TEST_SERVICE_ACCOUNT, ) assert e.match(regexp=r"source") @@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self): model_name=_TEST_MODEL_NAME, job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME, gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, + service_account=_TEST_SERVICE_ACCOUNT, ) assert e.match(regexp=r"destination") @@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self): gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, instances_format="wrong", bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + service_account=_TEST_SERVICE_ACCOUNT, ) assert e.match(regexp=r"accepted instances format") @@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self): gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, predictions_format="wrong", bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, + service_account=_TEST_SERVICE_ACCOUNT, ) assert e.match(regexp=r"accepted prediction format") @@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model( gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=True, + service_account=_TEST_SERVICE_ACCOUNT, ) assert ( create_batch_prediction_job_mock.call_args_list[0][1][ @@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model( gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE, gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=True, + service_account=_TEST_SERVICE_ACCOUNT, ) assert ( create_batch_prediction_job_mock.call_args_list[0][1][ diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index 5a1ad9c741..2b3f6a5276 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -1644,6 +1644,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, create_request_timeout=None, + service_account=_TEST_SERVICE_ACCOUNT, ) if not sync: @@ -1669,6 +1670,7 @@ def test_init_aiplatform_with_encryption_key_name_and_batch_predict_gcs_source_a predictions_format="jsonl", ), encryption_spec=_TEST_ENCRYPTION_SPEC, + service_account=_TEST_SERVICE_ACCOUNT, ) ) @@ -1693,6 +1695,7 @@ def test_batch_predict_gcs_source_and_dest( gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, create_request_timeout=None, + service_account=_TEST_SERVICE_ACCOUNT, ) if not sync: @@ -1711,6 +1714,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX, sync=sync, create_request_timeout=None, + service_account=_TEST_SERVICE_ACCOUNT, ) if not sync: @@ -1733,6 +1737,7 @@ def test_batch_predict_with_version(self, sync, create_batch_prediction_job_mock ), predictions_format="jsonl", ), + service_account=_TEST_SERVICE_ACCOUNT, ) ) @@ -1757,6 +1762,7 @@ def test_batch_predict_gcs_source_bq_dest( bigquery_destination_prefix=_TEST_BATCH_PREDICTION_BQ_PREFIX, sync=sync, create_request_timeout=None, + service_account=_TEST_SERVICE_ACCOUNT, ) if not sync: @@ -1781,6 +1787,7 @@ def test_batch_predict_gcs_source_bq_dest( ), predictions_format="bigquery", ), + service_account=_TEST_SERVICE_ACCOUNT, ) ) @@ -1817,6 +1824,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn sync=sync, create_request_timeout=None, batch_size=_TEST_BATCH_SIZE, + service_account=_TEST_SERVICE_ACCOUNT, ) if not sync: @@ -1857,6 +1865,7 @@ def test_batch_predict_with_all_args(self, create_batch_prediction_job_mock, syn ), labels=_TEST_LABEL, encryption_spec=_TEST_ENCRYPTION_SPEC, + service_account=_TEST_SERVICE_ACCOUNT, ) create_batch_prediction_job_mock.assert_called_once_with(