diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py
index bcfb9a9781..aba79623fe 100644
--- a/tests/unit/aiplatform/test_language_models.py
+++ b/tests/unit/aiplatform/test_language_models.py
@@ -2038,6 +2038,7 @@ def test_code_chat(self):
         )
 
         code_chat = model.start_chat(
+            context="We're working on large-scale production system.",
             max_output_tokens=128,
             temperature=0.2,
             stop_sequences=["\n"],
diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py
index 084bb4c9f7..820fe5a5c5 100644
--- a/vertexai/language_models/_language_models.py
+++ b/vertexai/language_models/_language_models.py
@@ -1287,6 +1287,7 @@ class CodeChatModel(_ChatModelBase):
         code_chat_model = CodeChatModel.from_pretrained("codechat-bison@001")
 
         code_chat = code_chat_model.start_chat(
+            context="I'm writing a large-scale enterprise application.",
             max_output_tokens=128,
             temperature=0.2,
         )
@@ -1301,6 +1302,7 @@ class CodeChatModel(_ChatModelBase):
     def start_chat(
         self,
         *,
+        context: Optional[str] = None,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
@@ -1309,6 +1311,9 @@ def start_chat(
         """Starts a chat session with the code chat model.
 
         Args:
+            context: Context shapes how the model responds throughout the conversation.
+                For example, you can use context to specify words the model can or
+                cannot use, topics to focus on or avoid, or the response format or style.
             max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
             temperature: Controls the randomness of predictions. Range: [0, 1].
             stop_sequences: Customized stop sequences to stop the decoding process.
@@ -1318,6 +1323,7 @@ def start_chat(
         """
         return CodeChatSession(
             model=self,
+            context=context,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             message_history=message_history,
@@ -1653,6 +1659,7 @@ class CodeChatSession(_ChatSessionBase):
     def __init__(
         self,
         model: CodeChatModel,
+        context: Optional[str] = None,
         max_output_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         message_history: Optional[List[ChatMessage]] = None,
@@ -1660,6 +1667,7 @@ def __init__(
     ):
         super().__init__(
             model=model,
+            context=context,
             max_output_tokens=max_output_tokens,
             temperature=temperature,
             message_history=message_history,
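
For reviewers, a minimal usage sketch of the new parameter, built from the docstring example in the diff above. The send_message call and response.text attribute are the existing chat-session API and are not touched by this change; the prompt string is illustrative only.

# A minimal sketch of the new `context` parameter in use, assuming the
# public vertexai.language_models surface shown in the diff. The
# send_message call is the pre-existing chat-session API.
from vertexai.language_models import CodeChatModel

code_chat_model = CodeChatModel.from_pretrained("codechat-bison@001")

# `context` now steers every turn of the session, mirroring the
# behavior of ChatModel.start_chat.
code_chat = code_chat_model.start_chat(
    context="I'm writing a large-scale enterprise application.",
    max_output_tokens=128,
    temperature=0.2,
)

# Hypothetical prompt; any coding question works here.
response = code_chat.send_message("Write a function to de-duplicate a list.")
print(response.text)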