diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index f7b4159153..79e4b503fc 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -570,6 +570,10 @@ def test_text_generation(self):
             )
 
             assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
+            assert (
+                response.safety_attributes["Violent"]
+                == _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
+            )
 
     def test_text_generation_ga(self):
         """Tests the text generation model."""
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 23ca3fa295..31b743afee 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -15,7 +15,7 @@
 """Classes for working with language models."""
 
 import dataclasses
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 from google.cloud import aiplatform
 from google.cloud.aiplatform import base
@@ -198,10 +198,19 @@ def tune_model(
 
 @dataclasses.dataclass
 class TextGenerationResponse:
-    """TextGenerationResponse represents a response of a language model."""
+    """TextGenerationResponse represents a response of a language model.
+    Attributes:
+        text: The generated text.
+        is_blocked: Whether the request was blocked.
+        safety_attributes: Scores for safety attributes.
+            Learn more about the safety attributes here:
+            https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
+    """
 
     text: str
     _prediction_response: Any
+    is_blocked: bool = False
+    safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict)
 
     def __repr__(self):
         return self.text
@@ -289,13 +298,23 @@ def _batch_predict(
             parameters=prediction_parameters,
         )
 
-        return [
-            TextGenerationResponse(
-                text=prediction["content"],
-                _prediction_response=prediction_response,
+        results = []
+        for prediction in prediction_response.predictions:
+            safety_attributes_dict = prediction.get("safetyAttributes", {})
+            results.append(
+                TextGenerationResponse(
+                    text=prediction["content"],
+                    _prediction_response=prediction_response,
+                    is_blocked=safety_attributes_dict.get("blocked", False),
+                    safety_attributes=dict(
+                        zip(
+                            safety_attributes_dict.get("categories", []),
+                            safety_attributes_dict.get("scores", []),
+                        )
+                    ),
+                )
             )
-            for prediction in prediction_response.predictions
-        ]
+        return results
 
 
 _TextGenerationModel = TextGenerationModel
@@ -690,9 +709,20 @@ def send_message(
             parameters=prediction_parameters,
         )
 
+        prediction = prediction_response.predictions[0]
+        safety_attributes = prediction.get("safetyAttributes", {})
         response_obj = TextGenerationResponse(
-            text=prediction_response.predictions[0]["candidates"][0]["content"],
+            text=prediction["candidates"][0]["content"]
+            if prediction.get("candidates")
+            else None,
             _prediction_response=prediction_response,
+            is_blocked=safety_attributes.get("blocked", False),
+            safety_attributes=dict(
+                zip(
+                    safety_attributes.get("categories", []),
+                    safety_attributes.get("scores", []),
+                )
+            ),
         )
         response_text = response_obj.text
 