Skip to content

Commit

Permalink
feat: LLM - Exposed the safety attributes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 541035595
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 16, 2023
1 parent 21e48fe commit 01ba3ca
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
4 changes: 4 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,10 @@ def test_text_generation(self):
)

assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
assert (
response.safety_attributes["Violent"]
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
)

def test_text_generation_ga(self):
"""Tests the text generation model."""
Expand Down
48 changes: 39 additions & 9 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Classes for working with language models."""

import dataclasses
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from google.cloud import aiplatform
from google.cloud.aiplatform import base
Expand Down Expand Up @@ -198,10 +198,19 @@ def tune_model(

@dataclasses.dataclass
class TextGenerationResponse:
"""TextGenerationResponse represents a response of a language model."""
"""TextGenerationResponse represents a response of a language model.
Attributes:
text: The generated text
is_blocked: Whether the the request was blocked.
safety_attributes: Scores for safety attributes.
Learn more about the safety attributes here:
https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_descriptions
"""

text: str
_prediction_response: Any
is_blocked: bool = False
safety_attributes: Dict[str, float] = dataclasses.field(default_factory=dict)

def __repr__(self):
return self.text
Expand Down Expand Up @@ -289,13 +298,23 @@ def _batch_predict(
parameters=prediction_parameters,
)

return [
TextGenerationResponse(
text=prediction["content"],
_prediction_response=prediction_response,
results = []
for prediction in prediction_response.predictions:
safety_attributes_dict = prediction.get("safetyAttributes", {})
results.append(
TextGenerationResponse(
text=prediction["content"],
_prediction_response=prediction_response,
is_blocked=safety_attributes_dict.get("blocked", False),
safety_attributes=dict(
zip(
safety_attributes_dict.get("categories", []),
safety_attributes_dict.get("scores", []),
)
),
)
)
for prediction in prediction_response.predictions
]
return results


_TextGenerationModel = TextGenerationModel
Expand Down Expand Up @@ -690,9 +709,20 @@ def send_message(
parameters=prediction_parameters,
)

prediction = prediction_response.predictions[0]
safety_attributes = prediction["safetyAttributes"]
response_obj = TextGenerationResponse(
text=prediction_response.predictions[0]["candidates"][0]["content"],
text=prediction["candidates"][0]["content"]
if prediction.get("candidates")
else None,
_prediction_response=prediction_response,
is_blocked=safety_attributes.get("blocked", False),
safety_attributes=dict(
zip(
safety_attributes.get("categories", []),
safety_attributes.get("scores", []),
)
),
)
response_text = response_obj.text

Expand Down

0 comments on commit 01ba3ca

Please sign in to comment.