From 4a531437bb5bad3aa474b7359f841341c82267c5 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 18 Dec 2024 13:54:07 -0800 Subject: [PATCH] core[patch], openai[patch]: Handle OpenAI developer msg (#28794) - Convert developer openai messages to SystemMessage - store additional_kwargs={"__openai_role__": "developer"} so that the correct role can be reconstructed if needed - update ChatOpenAI to read in openai_role --------- Co-authored-by: Erick Friis --- libs/core/langchain_core/messages/utils.py | 13 +++++--- .../tests/unit_tests/messages/test_utils.py | 21 ++++++++++++ .../langchain_openai/chat_models/base.py | 27 +++++++++++++--- .../chat_models/test_base.py | 13 ++++++++ .../tests/unit_tests/chat_models/test_base.py | 32 +++++++++++++++++++ 5 files changed, 96 insertions(+), 10 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 452d59a5d78bf..435bd1e2f1e07 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -221,14 +221,14 @@ def _create_message_from_message_type( tool_call_id: (str) the tool call id. Default is None. tool_calls: (list[dict[str, Any]]) the tool calls. Default is None. id: (str) the id of the message. Default is None. - **additional_kwargs: (dict[str, Any]) additional keyword arguments. + additional_kwargs: (dict[str, Any]) additional keyword arguments. Returns: a message of the appropriate type. Raises: ValueError: if the message type is not one of "human", "user", "ai", - "assistant", "system", "function", or "tool". + "assistant", "function", "tool", "system", or "developer". """ kwargs: dict[str, Any] = {} if name is not None: @@ -261,7 +261,10 @@ def _create_message_from_message_type( message: BaseMessage = HumanMessage(content=content, **kwargs) elif message_type in ("ai", "assistant"): message = AIMessage(content=content, **kwargs) - elif message_type == "system": + elif message_type in ("system", "developer"): + if message_type == "developer": + kwargs["additional_kwargs"] = kwargs.get("additional_kwargs") or {} + kwargs["additional_kwargs"]["__openai_role__"] = "developer" message = SystemMessage(content=content, **kwargs) elif message_type == "function": message = FunctionMessage(content=content, **kwargs) @@ -273,7 +276,7 @@ def _create_message_from_message_type( else: msg = ( f"Unexpected message type: '{message_type}'. Use one of 'human'," - f" 'user', 'ai', 'assistant', 'function', 'tool', or 'system'." + f" 'user', 'ai', 'assistant', 'function', 'tool', 'system', or 'developer'." ) msg = create_message(message=msg, error_code=ErrorCode.MESSAGE_COERCION_FAILURE) raise ValueError(msg) @@ -1385,7 +1388,7 @@ def _get_message_openai_role(message: BaseMessage) -> str: elif isinstance(message, ToolMessage): return "tool" elif isinstance(message, SystemMessage): - return "system" + return message.additional_kwargs.get("__openai_role__", "system") elif isinstance(message, FunctionMessage): return "function" elif isinstance(message, ChatMessage): diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 9f4a9a4cc6c36..3a32c2984952a 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -469,6 +469,7 @@ def test_convert_to_messages() -> None: message_like: list = [ # BaseMessage SystemMessage("1"), + SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}), HumanMessage([{"type": "image_url", "image_url": {"url": "2.1"}}], name="2.2"), AIMessage( [ @@ -503,6 +504,7 @@ def test_convert_to_messages() -> None: ToolMessage("5.1", tool_call_id="5.2", name="5.3"), # OpenAI dict {"role": "system", "content": "6"}, + {"role": "developer", "content": "6.1"}, { "role": "user", "content": [{"type": "image_url", "image_url": {"url": "7.1"}}], @@ -526,6 +528,7 @@ def test_convert_to_messages() -> None: {"role": "tool", "content": "10.1", "tool_call_id": "10.2"}, # Tuple/List ("system", "11.1"), + ("developer", "11.2"), ("human", [{"type": "image_url", "image_url": {"url": "12.1"}}]), ( "ai", @@ -551,6 +554,9 @@ def test_convert_to_messages() -> None: ] expected = [ SystemMessage(content="1"), + SystemMessage( + content="1.1", additional_kwargs={"__openai_role__": "developer"} + ), HumanMessage( content=[{"type": "image_url", "image_url": {"url": "2.1"}}], name="2.2" ), @@ -586,6 +592,9 @@ def test_convert_to_messages() -> None: ), ToolMessage(content="5.1", name="5.3", tool_call_id="5.2"), SystemMessage(content="6"), + SystemMessage( + content="6.1", additional_kwargs={"__openai_role__": "developer"} + ), HumanMessage( content=[{"type": "image_url", "image_url": {"url": "7.1"}}], name="7.2" ), @@ -603,6 +612,9 @@ def test_convert_to_messages() -> None: ), ToolMessage(content="10.1", tool_call_id="10.2"), SystemMessage(content="11.1"), + SystemMessage( + content="11.2", additional_kwargs={"__openai_role__": "developer"} + ), HumanMessage(content=[{"type": "image_url", "image_url": {"url": "12.1"}}]), AIMessage( content=[ @@ -937,3 +949,12 @@ def test_convert_to_openai_messages_mixed_content_types() -> None: assert isinstance(result[0]["content"][0], dict) assert isinstance(result[0]["content"][1], dict) assert isinstance(result[0]["content"][2], dict) + + +def test_convert_to_openai_messages_developer() -> None: + messages: list = [ + SystemMessage("a", additional_kwargs={"__openai_role__": "developer"}), + {"role": "developer", "content": "a"}, + ] + result = convert_to_openai_messages(messages) + assert result == [{"role": "developer", "content": "a"}] * 2 diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index e312ed249c88e..d866e30fe36e1 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -139,8 +139,17 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: tool_calls=tool_calls, invalid_tool_calls=invalid_tool_calls, ) - elif role == "system": - return SystemMessage(content=_dict.get("content", ""), name=name, id=id_) + elif role in ("system", "developer"): + if role == "developer": + additional_kwargs = {"__openai_role__": role} + else: + additional_kwargs = {} + return SystemMessage( + content=_dict.get("content", ""), + name=name, + id=id_, + additional_kwargs=additional_kwargs, + ) elif role == "function": return FunctionMessage( content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_ @@ -233,7 +242,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: ) message_dict["audio"] = audio elif isinstance(message, SystemMessage): - message_dict["role"] = "system" + message_dict["role"] = message.additional_kwargs.get( + "__openai_role__", "system" + ) elif isinstance(message, FunctionMessage): message_dict["role"] = "function" elif isinstance(message, ToolMessage): @@ -284,8 +295,14 @@ def _convert_delta_to_message_chunk( id=id_, tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] ) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content, id=id_) + elif role in ("system", "developer") or default_class == SystemMessageChunk: + if role == "developer": + additional_kwargs = {"__openai_role__": "developer"} + else: + additional_kwargs = {} + return SystemMessageChunk( + content=content, id=id_, additional_kwargs=additional_kwargs + ) elif role == "function" or default_class == FunctionMessageChunk: return FunctionMessageChunk(content=content, name=_dict["name"], id=id_) elif role == "tool" or default_class == ToolMessageChunk: diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 1204ccef87c4e..f997ce44d6852 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1097,3 +1097,16 @@ def test_o1_max_tokens() -> None: "how are you" ) assert isinstance(response, AIMessage) + + +def test_developer_message() -> None: + llm = ChatOpenAI(model="o1", max_tokens=10) # type: ignore[call-arg] + response = llm.invoke( + [ + {"role": "developer", "content": "respond in all caps"}, + {"role": "user", "content": "HOW ARE YOU"}, + ] + ) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert response.content.upper() == response.content diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index aea480a748177..6e889b80a562c 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -100,6 +100,16 @@ def test__convert_dict_to_message_system() -> None: assert _convert_message_to_dict(expected_output) == message +def test__convert_dict_to_message_developer() -> None: + message = {"role": "developer", "content": "foo"} + result = _convert_dict_to_message(message) + expected_output = SystemMessage( + content="foo", additional_kwargs={"__openai_role__": "developer"} + ) + assert result == expected_output + assert _convert_message_to_dict(expected_output) == message + + def test__convert_dict_to_message_system_with_name() -> None: message = {"role": "system", "content": "foo", "name": "test"} result = _convert_dict_to_message(message) @@ -850,3 +860,25 @@ class JokeWithEvaluation(TypedDict): self_evaluation: SelfEvaluation llm.with_structured_output(JokeWithEvaluation, method="json_schema") + + +def test__get_request_payload() -> None: + llm = ChatOpenAI(model="gpt-4o-2024-08-06") + messages: list = [ + SystemMessage("hello"), + SystemMessage("bye", additional_kwargs={"__openai_role__": "developer"}), + {"role": "human", "content": "how are you"}, + ] + expected = { + "messages": [ + {"role": "system", "content": "hello"}, + {"role": "developer", "content": "bye"}, + {"role": "user", "content": "how are you"}, + ], + "model": "gpt-4o-2024-08-06", + "stream": False, + "n": 1, + "temperature": 0.7, + } + payload = llm._get_request_payload(messages) + assert payload == expected