Skip to content

Commit

Permalink
Added normalize_name and assert_valid name
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanthacker committed Dec 16, 2024
1 parent 462fa5f commit 45925b2
Showing 1 changed file with 26 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import re
from asyncio import Task
from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast
from inspect import getfullargspec
Expand Down Expand Up @@ -88,7 +89,7 @@ def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:


def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
# assert_valid_name(message.source)
assert_valid_name(message.source)
if isinstance(message.content, str):
return AzureUserMessage(content=message.content)
else:
Expand All @@ -106,7 +107,7 @@ def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:


def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
# assert_valid_name(message.source)
assert_valid_name(message.source)
if isinstance(message.content, list):
return AzureAssistantMessage(
tool_calls=[_func_call_to_azure(x) for x in message.content],
Expand All @@ -130,7 +131,28 @@ def to_azure_message(message: LLMMessage):
return _tool_message_to_azure(message)


# TODO: Add Support for Github Models
def normalize_name(name: str) -> str:
"""
LLMs sometimes ask functions while ignoring their own format requirements, this function should be used to replace invalid characters with "_".
Prefer _assert_valid_name for validating user configuration or input
"""
return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]


def assert_valid_name(name: str) -> str:
"""
Ensure that configured names are valid, raises ValueError if not.
For munging LLM responses use _normalize_name to ensure LLM specified names don't break the API.
"""
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
if len(name) > 64:
raise ValueError(f"Invalid name: {name}. Name must be less than 64 characters.")
return name


class AzureAIChatCompletionClient(ChatCompletionClient):
def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]):
if "endpoint" not in kwargs:
Expand Down Expand Up @@ -222,7 +244,7 @@ async def create(
FunctionCall(
id=x.id,
arguments=x.function.arguments,
name=x.function.name,
name=normalize_name(x.function.name),
)
for x in choice.message.tool_calls
]
Expand Down

0 comments on commit 45925b2

Please sign in to comment.