Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Refactoring and fix bug. AuthorRole instead of string literals. #6391

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai import PromptExecutionSettings
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextCompletion
from semantic_kernel.contents import ChatHistory
from semantic_kernel.contents import AuthorRole, ChatHistory
from semantic_kernel.functions import KernelArguments
from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig

Expand Down Expand Up @@ -204,7 +204,9 @@ def _check_banned_words(banned_list, actual_list) -> bool:

def _format_output(chat, banned_words) -> None:
print("--- Checking for banned words ---")
chat_bot_ans_words = [word for msg in chat.messages if msg.role == "assistant" for word in msg.content.split()]
chat_bot_ans_words = [
word for msg in chat.messages if msg.role == AuthorRole.ASSISTANT for word in msg.content.split()
]
if _check_banned_words(banned_words, chat_bot_ans_words):
print("None of the banned words were found in the answer")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import reduce

from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.filters.filter_types import FilterTypes
Expand Down Expand Up @@ -39,7 +40,7 @@ async def override_stream(stream):
async for partial in stream:
yield partial
except Exception as e:
yield [StreamingChatMessageContent(author="assistant", content=f"Exception caught: {e}")]
yield [StreamingChatMessageContent(role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}")]

stream = context.result.value
context.result = FunctionResult(function=context.result.function, value=override_stream(stream))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
from semantic_kernel.connectors.ai.ollama.utils import AsyncSession
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
Expand Down Expand Up @@ -66,7 +67,7 @@ async def get_chat_message_contents(
ChatMessageContent(
inner_content=response_object,
ai_model_id=self.ai_model_id,
role="assistant",
role=AuthorRole.ASSISTANT,
content=response_object.get("message", {"content": None}).get("content", None),
)
]
Expand Down Expand Up @@ -105,7 +106,7 @@ async def get_streaming_chat_message_contents(
break
yield [
StreamingChatMessageContent(
role="assistant",
role=AuthorRole.ASSISTANT,
choice_index=0,
inner_content=body,
ai_model_id=self.ai_model_id,
Expand All @@ -131,7 +132,7 @@ async def get_text_contents(
"""
if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id
settings.messages = [{"role": "user", "content": prompt}]
settings.messages = [{"role": AuthorRole.USER, "content": prompt}]
settings.stream = False
async with (
AsyncSession(self.session) as session,
Expand Down Expand Up @@ -165,7 +166,7 @@ async def get_streaming_text_contents(
"""
if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id
settings.messages = [{"role": "user", "content": prompt}]
settings.messages = [{"role": AuthorRole.USER, "content": prompt}]
settings.stream = True
async with (
AsyncSession(self.session) as session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,12 @@ async def get_streaming_chat_message_contents(

def _chat_message_content_to_dict(self, message: "ChatMessageContent") -> dict[str, str | None]:
msg = super()._chat_message_content_to_dict(message)
if message.role == "assistant":
if message.role == AuthorRole.ASSISTANT:
if tool_calls := getattr(message, "tool_calls", None):
msg["tool_calls"] = [tool_call.model_dump() for tool_call in tool_calls]
if function_call := getattr(message, "function_call", None):
msg["function_call"] = function_call.model_dump_json()
if message.role == "tool":
if message.role == AuthorRole.TOOL:
if tool_call_id := getattr(message, "tool_call_id", None):
msg["tool_call_id"] = tool_call_id
if message.metadata and "function" in message.metadata:
Expand Down
5 changes: 3 additions & 2 deletions python/semantic_kernel/contents/function_result_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pydantic import field_validator

from semantic_kernel.contents.author_role import AuthorRole
from semantic_kernel.contents.const import FUNCTION_RESULT_CONTENT_TAG, TEXT_CONTENT_TAG
from semantic_kernel.contents.kernel_content import KernelContent
from semantic_kernel.contents.text_content import TextContent
Expand Down Expand Up @@ -104,8 +105,8 @@ def to_chat_message_content(self, unwrap: bool = False) -> "ChatMessageContent":
from semantic_kernel.contents.chat_message_content import ChatMessageContent

if unwrap:
return ChatMessageContent(role="tool", items=[self.result]) # type: ignore
return ChatMessageContent(role="tool", items=[self]) # type: ignore
return ChatMessageContent(role=AuthorRole.TOOL, items=[self.result]) # type: ignore
return ChatMessageContent(role=AuthorRole.TOOL, items=[self]) # type: ignore

def to_dict(self) -> dict[str, str]:
"""Convert the instance to a dictionary."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletionBase
from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent, TextContent
from semantic_kernel.contents import AuthorRole, ChatMessageContent, StreamingChatMessageContent, TextContent
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.exceptions import FunctionCallInvalidArgumentsException
Expand Down Expand Up @@ -64,7 +64,9 @@ async def test_complete_chat(tool_call, kernel: Kernel):
settings.function_call_behavior = None
mock_function_call = MagicMock(spec=FunctionCallContent)
mock_text = MagicMock(spec=TextContent)
mock_message = ChatMessageContent(role="assistant", items=[mock_function_call] if tool_call else [mock_text])
mock_message = ChatMessageContent(
role=AuthorRole.ASSISTANT, items=[mock_function_call] if tool_call else [mock_text]
)
mock_message_content = [mock_message]
arguments = KernelArguments()

Expand Down
4 changes: 2 additions & 2 deletions python/tests/unit/contents/test_chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def test_dump():
)
dump = chat_history.model_dump(exclude_none=True)
assert dump is not None
assert dump["messages"][0]["role"] == "system"
assert dump["messages"][0]["role"] == AuthorRole.SYSTEM
assert dump["messages"][0]["items"][0]["text"] == system_msg
assert dump["messages"][1]["role"] == "user"
assert dump["messages"][1]["role"] == AuthorRole.USER
assert dump["messages"][1]["items"][0]["text"] == "Message"


Expand Down
30 changes: 15 additions & 15 deletions python/tests/unit/contents/test_chat_message_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@


def test_cmc():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert len(message.items) == 1


def test_cmc_str():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
eavanvalkenburg marked this conversation as resolved.
Show resolved Hide resolved
assert str(message) == "Hello, world!"


def test_cmc_full():
message = ChatMessageContent(
role="user",
role=AuthorRole.USER,
name="username",
content="Hello, world!",
inner_content="Hello, world!",
Expand All @@ -42,14 +42,14 @@ def test_cmc_full():


def test_cmc_items():
message = ChatMessageContent(role="user", items=[TextContent(text="Hello, world!")])
message = ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="Hello, world!")])
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert len(message.items) == 1


def test_cmc_items_and_content():
message = ChatMessageContent(role="user", content="text", items=[TextContent(text="Hello, world!")])
message = ChatMessageContent(role=AuthorRole.USER, content="text", items=[TextContent(text="Hello, world!")])
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert message.items[0].text == "Hello, world!"
Expand All @@ -59,7 +59,7 @@ def test_cmc_items_and_content():

def test_cmc_multiple_items():
message = ChatMessageContent(
role="system",
role=AuthorRole.SYSTEM,
items=[
TextContent(text="Hello, world!"),
TextContent(text="Hello, world!"),
Expand All @@ -71,7 +71,7 @@ def test_cmc_multiple_items():


def test_cmc_content_set():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
message.content = "Hello, world to you too!"
Expand All @@ -82,7 +82,7 @@ def test_cmc_content_set():


def test_cmc_content_set_empty():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
message.items.pop()
Expand All @@ -92,7 +92,7 @@ def test_cmc_content_set_empty():


def test_cmc_to_element():
message = ChatMessageContent(role="user", content="Hello, world!", name=None)
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!", name=None)
element = message.to_element()
assert element.tag == "message"
assert element.attrib == {"role": "user"}
Expand All @@ -103,13 +103,13 @@ def test_cmc_to_element():


def test_cmc_to_prompt():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
prompt = message.to_prompt()
assert prompt == '<message role="user"><text>Hello, world!</text></message>'


def test_cmc_from_element():
element = ChatMessageContent(role="user", content="Hello, world!").to_element()
element = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!").to_element()
message = ChatMessageContent.from_element(element)
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
Expand Down Expand Up @@ -182,22 +182,22 @@ def test_cmc_from_element_content_parse(xml_content, user, text_content, length)


def test_cmc_serialize():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
dumped = message.model_dump()
assert dumped["role"] == "user"
assert dumped["role"] == AuthorRole.USER
assert dumped["items"][0]["text"] == "Hello, world!"


def test_cmc_to_dict():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.to_dict() == {
"role": "user",
"content": "Hello, world!",
}


def test_cmc_to_dict_keys():
message = ChatMessageContent(role="user", content="Hello, world!")
message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.to_dict(role_key="author", content_key="text") == {
"author": "user",
"text": "Hello, world!",
Expand Down
Loading
Loading