diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index 4395fdca40..6ed1c5d893 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -45,7 +45,7 @@
     artifact as gca_artifact,
     prediction_service as gca_prediction_service,
     context as gca_context,
-    endpoint as gca_endpoint,
+    endpoint_v1 as gca_endpoint,
     pipeline_job as gca_pipeline_job,
     pipeline_state as gca_pipeline_state,
     deployed_model_ref_v1,
@@ -1030,6 +1030,11 @@ def get_endpoint_mock():
         get_endpoint_mock.return_value = gca_endpoint.Endpoint(
             display_name="test-display-name",
             name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
+            deployed_models=[
+                gca_endpoint.DeployedModel(
+                    model=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
+                ),
+            ],
         )
         yield get_endpoint_mock
 
@@ -2420,7 +2425,10 @@ def test_text_embedding_ga(self):
                 assert len(vector) == _TEXT_EMBEDDING_VECTOR_LENGTH
                 assert vector == _TEST_TEXT_EMBEDDING_PREDICTION["embeddings"]["values"]
 
-    def test_batch_prediction(self):
+    def test_batch_prediction(
+        self,
+        get_endpoint_mock,
+    ):
         """Tests batch prediction."""
         aiplatform.init(
             project=_TEST_PROJECT,
@@ -2447,7 +2455,29 @@
                 model_parameters={"temperature": 0.1},
             )
             mock_create.assert_called_once_with(
-                model_name="publishers/google/models/text-bison@001",
+                model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001",
+                job_display_name=None,
+                gcs_source="gs://test-bucket/test_table.jsonl",
+                gcs_destination_prefix="gs://test-bucket/results/",
+                model_parameters={"temperature": 0.1},
+            )
+
+        # Testing tuned model batch prediction
+        tuned_model = language_models.TextGenerationModel(
+            model_id=model._model_id,
+            endpoint_name=test_constants.EndpointConstants._TEST_ENDPOINT_NAME,
+        )
+        with mock.patch.object(
+            target=aiplatform.BatchPredictionJob,
+            attribute="create",
+        ) as mock_create:
+            tuned_model.batch_predict(
+                dataset="gs://test-bucket/test_table.jsonl",
+                destination_uri_prefix="gs://test-bucket/results/",
+                model_parameters={"temperature": 0.1},
+            )
+            mock_create.assert_called_once_with(
+                model_name=test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME,
                 job_display_name=None,
                 gcs_source="gs://test-bucket/test_table.jsonl",
                 gcs_destination_prefix="gs://test-bucket/results/",
@@ -2481,7 +2511,7 @@ def test_batch_prediction_for_text_embedding(self):
                 model_parameters={},
             )
             mock_create.assert_called_once_with(
-                model_name="publishers/google/models/textembedding-gecko@001",
+                model_name=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001",
                 job_display_name=None,
                 gcs_source="gs://test-bucket/test_table.jsonl",
                 gcs_destination_prefix="gs://test-bucket/results/",
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 7f2f15cb7d..f542f52bf8 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -839,11 +839,6 @@ def batch_predict(
             raise ValueError(f"Unsupported destination_uri: {destination_uri_prefix}")
 
         model_name = self._model_resource_name
-        # TODO(b/284512065): Batch prediction service does not support
-        # fully qualified publisher model names yet
-        publishers_index = model_name.index("/publishers/")
-        if publishers_index > 0:
-            model_name = model_name[publishers_index + 1 :]
 
         job = aiplatform.BatchPredictionJob.create(
             model_name=model_name,
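
For context, below is a minimal sketch of the workaround this diff deletes from batch_predict(). The helper name _old_batch_model_name is hypothetical; its body mirrors the removed TODO(b/284512065) block. It illustrates that fully qualified publisher model names used to be truncated before submission, and that a tuned model's resource name, which contains no "/publishers/" segment, would make str.index() raise ValueError, which is consistent with why the updated test now adds a deployed model to the endpoint mock and exercises the tuned-model path.

# Hypothetical helper replaying the deleted TODO(b/284512065) workaround.
def _old_batch_model_name(model_resource_name: str) -> str:
    # str.index() raises ValueError when "/publishers/" is absent, as it is
    # for tuned-model names like "projects/.../locations/.../models/...".
    publishers_index = model_resource_name.index("/publishers/")
    if publishers_index > 0:
        # Strip the "projects/.../locations/.../" prefix.
        return model_resource_name[publishers_index + 1 :]
    return model_resource_name

publisher_model = (
    "projects/test-project/locations/us-central1"
    "/publishers/google/models/text-bison@001"
)
# Old behavior: the fully qualified name lost its project/location prefix.
assert _old_batch_model_name(publisher_model) == (
    "publishers/google/models/text-bison@001"
)
# After this diff, batch_predict passes self._model_resource_name through
# unchanged, which the updated tests assert with fully qualified publisher
# names and with the tuned model's deployed-model resource name.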