From 587df744e2b6c4b3e1a96ff69937697fe80a97be Mon Sep 17 00:00:00 2001 From: Alexey Volkov Date: Mon, 9 Oct 2023 18:54:47 -0700 Subject: [PATCH] feat: LLM - Added support for multiple chat response candidates PiperOrigin-RevId: 572100735 --- tests/unit/aiplatform/test_language_models.py | 71 +++++++++++++++++++ vertexai/language_models/_language_models.py | 64 +++++++++++------ 2 files changed, 113 insertions(+), 22 deletions(-) diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index 40e7ae8352..7b253da5f8 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -238,6 +238,30 @@ } ], } +_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION = { + "safetyAttributes": [ + { + "scores": [], + "categories": [], + "blocked": False, + }, + { + "scores": [0.1], + "categories": ["Finance"], + "blocked": True, + }, + ], + "candidates": [ + { + "author": "1", + "content": "Chat response 2", + }, + { + "author": "1", + "content": "", + }, + ], +} _TEST_CHAT_PREDICTION_STREAMING = [ { @@ -2076,6 +2100,53 @@ def test_chat_ga(self): assert prediction_parameters["topP"] == message_top_p assert prediction_parameters["stopSequences"] == message_stop_sequences + def test_chat_model_send_message_with_multiple_candidates(self): + """Tests the chat generation model with multiple candidates.""" + + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _CHAT_BISON_PUBLISHER_MODEL_DICT + ), + ) as mock_get_publisher_model: + model = language_models.ChatModel.from_pretrained("chat-bison@001") + + mock_get_publisher_model.assert_called_once_with( + name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY + ) + + chat = model.start_chat() + + gca_predict_response1 = gca_prediction_service.PredictResponse() + gca_predict_response1.predictions.append( + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION + ) + + with mock.patch.object( + target=prediction_service_client.PredictionServiceClient, + attribute="predict", + return_value=gca_predict_response1, + ): + message_text1 = "Are my favorite movies based on a book series?" + expected_response_candidates = ( + _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION["candidates"] + ) + expected_candidate_0 = expected_response_candidates[0]["content"] + expected_candidate_1 = expected_response_candidates[1]["content"] + + response = chat.send_message(message_text1, candidate_count=2) + assert response.text == expected_candidate_0 + assert len(response.candidates) == 2 + assert response.candidates[0].text == expected_candidate_0 + assert response.candidates[1].text == expected_candidate_1 + + assert len(chat.message_history) == 2 + assert chat.message_history[0].author == chat.USER_AUTHOR + assert chat.message_history[0].content == message_text1 + assert chat.message_history[1].author == chat.MODEL_AUTHOR + assert chat.message_history[1].content == expected_candidate_0 + def test_chat_model_send_message_streaming(self): """Tests the chat generation model.""" with mock.patch.object( diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 663e552b9e..9a13ae5916 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -1615,6 +1615,7 @@ def _prepare_request( top_k: Optional[int] = None, top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, + candidate_count: Optional[int] = None, ) -> _PredictionRequest: """Prepares a request for the language model. @@ -1629,6 +1630,7 @@ def _prepare_request( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Uses the value specified when calling `ChatModel.start_chat` by default. stop_sequences: Customized stop sequences to stop the decoding process. + candidate_count: Number of candidates to return. Returns: A `_PredictionRequest` object. @@ -1660,6 +1662,9 @@ def _prepare_request( if stop_sequences: prediction_parameters["stopSequences"] = stop_sequences + if candidate_count is not None: + prediction_parameters["candidateCount"] = candidate_count + message_structs = [] for past_message in self._message_history: message_structs.append( @@ -1697,8 +1702,7 @@ def _parse_chat_prediction_response( cls, prediction_response: aiplatform.models.Prediction, prediction_idx: int = 0, - candidate_idx: int = 0, - ) -> TextGenerationResponse: + ) -> MultiCandidateTextGenerationResponse: """Parses prediction response for chat models. Args: @@ -1707,25 +1711,33 @@ def _parse_chat_prediction_response( candidate_idx: Index of the candidate to parse. Returns: - A `TextGenerationResponse` object. + A `MultiCandidateTextGenerationResponse` object. """ prediction = prediction_response.predictions[prediction_idx] - # ! Note: For chat models, the safetyAttributes is a list. - safety_attributes = prediction["safetyAttributes"][candidate_idx] - return TextGenerationResponse( - text=prediction["candidates"][candidate_idx]["content"] - if prediction.get("candidates") - else None, + candidate_count = len(prediction["candidates"]) + candidates = [] + for candidate_idx in range(candidate_count): + safety_attributes = prediction["safetyAttributes"][candidate_idx] + candidate_response = TextGenerationResponse( + text=prediction["candidates"][candidate_idx]["content"], + _prediction_response=prediction_response, + is_blocked=safety_attributes.get("blocked", False), + safety_attributes=dict( + zip( + # Unlike with normal prediction, in streaming prediction + # categories and scores can be None + safety_attributes.get("categories") or [], + safety_attributes.get("scores") or [], + ) + ), + ) + candidates.append(candidate_response) + return MultiCandidateTextGenerationResponse( + text=candidates[0].text, _prediction_response=prediction_response, - is_blocked=safety_attributes.get("blocked", False), - safety_attributes=dict( - zip( - # Unlike with normal prediction, in streaming prediction - # categories and scores can be None - safety_attributes.get("categories") or [], - safety_attributes.get("scores") or [], - ) - ), + is_blocked=candidates[0].is_blocked, + safety_attributes=candidates[0].safety_attributes, + candidates=candidates, ) def send_message( @@ -1737,7 +1749,8 @@ def send_message( top_k: Optional[int] = None, top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, - ) -> "TextGenerationResponse": + candidate_count: Optional[int] = None, + ) -> "MultiCandidateTextGenerationResponse": """Sends message to the language model and gets a response. Args: @@ -1751,9 +1764,11 @@ def send_message( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Uses the value specified when calling `ChatModel.start_chat` by default. stop_sequences: Customized stop sequences to stop the decoding process. + candidate_count: Number of candidates to return. Returns: - A `TextGenerationResponse` object that contains the text produced by the model. + A `MultiCandidateTextGenerationResponse` object that contains the + text produced by the model. """ prediction_request = self._prepare_request( message=message, @@ -1762,6 +1777,7 @@ def send_message( top_k=top_k, top_p=top_p, stop_sequences=stop_sequences, + candidate_count=candidate_count, ) prediction_response = self._model._endpoint.predict( @@ -1791,7 +1807,8 @@ async def send_message_async( top_k: Optional[int] = None, top_p: Optional[float] = None, stop_sequences: Optional[List[str]] = None, - ) -> "TextGenerationResponse": + candidate_count: Optional[int] = None, + ) -> "MultiCandidateTextGenerationResponse": """Asynchronously sends message to the language model and gets a response. Args: @@ -1805,9 +1822,11 @@ async def send_message_async( top_p: The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95. Uses the value specified when calling `ChatModel.start_chat` by default. stop_sequences: Customized stop sequences to stop the decoding process. + candidate_count: Number of candidates to return. Returns: - A `TextGenerationResponse` object that contains the text produced by the model. + A `MultiCandidateTextGenerationResponse` object that contains + the text produced by the model. """ prediction_request = self._prepare_request( message=message, @@ -1816,6 +1835,7 @@ async def send_message_async( top_k=top_k, top_p=top_p, stop_sequences=stop_sequences, + candidate_count=candidate_count, ) prediction_response = await self._model._endpoint.predict_async(