@@ -24,6 +24,7 @@
from nat.cli.register_workflow import register_function
from nat.data_models.api_server import ChatRequest
from nat.data_models.api_server import ChatResponse
from nat.data_models.api_server import Usage
from nat.data_models.component_ref import FunctionRef
from nat.data_models.function import FunctionBaseConfig
from nat.data_models.interactive import HumanPromptText
@@ -161,7 +162,11 @@ async def handle_recursion_error(input_message: ChatRequest) -> ChatResponse:

# If user doesn't approve, return error message
if not selected_option:
return ChatResponse.from_string("I seem to be having a problem.")
error_msg = "I seem to be having a problem."

# Create usage statistics for error response
usage = Usage(prompt_tokens=None, completion_tokens=None, total_tokens=None)
return ChatResponse.from_string(error_msg, usage=usage)

# If we exhausted all retries, return the last response
return response
@@ -202,11 +207,19 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse:
return await handle_recursion_error(input_message)

# User declined - return error message
return ChatResponse.from_string("I seem to be having a problem.")
error_msg = "I seem to be having a problem."

# Create usage statistics for error response
usage = Usage(prompt_tokens=None, completion_tokens=None, total_tokens=None)
return ChatResponse.from_string(error_msg, usage=usage)

except Exception:
# Handle any other unexpected exceptions
return ChatResponse.from_string("I seem to be having a problem.")
error_msg = "I seem to be having a problem."

# Create usage statistics for error response
usage = Usage(prompt_tokens=None, completion_tokens=None, total_tokens=None)
return ChatResponse.from_string(error_msg, usage=usage)

yield FunctionInfo.from_fn(_response_fn, description=config.description)
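
Note: each error path above now builds a ChatResponse whose Usage deliberately carries no token counts. A minimal sketch of the resulting construction, assuming only the imports shown in this diff (_error_response is a hypothetical helper name, not part of the PR):

from nat.data_models.api_server import ChatResponse
from nat.data_models.api_server import Usage

def _error_response() -> ChatResponse:
    # Token counts are unknown on the error path, so every Usage field is
    # left as None (the api_server.py change below makes these optional).
    usage = Usage(prompt_tokens=None, completion_tokens=None, total_tokens=None)
    return ChatResponse.from_string("I seem to be having a problem.", usage=usage)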

10 changes: 9 additions & 1 deletion packages/nvidia_nat_test/src/nat/test/functions.py
@@ -21,6 +21,7 @@
from nat.data_models.api_server import ChatRequest
from nat.data_models.api_server import ChatResponse
from nat.data_models.api_server import ChatResponseChunk
from nat.data_models.api_server import Usage
from nat.data_models.function import FunctionBaseConfig


@@ -35,7 +36,14 @@ async def inner(message: str) -> str:
return message

async def inner_oai(message: ChatRequest) -> ChatResponse:
return ChatResponse.from_string(message.messages[0].content)
content = message.messages[0].content

# Create usage statistics for the response
prompt_tokens = sum(len(str(msg.content).split()) for msg in message.messages)
completion_tokens = len(content.split()) if content else 0
total_tokens = prompt_tokens + completion_tokens
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
return ChatResponse.from_string(content, usage=usage)

if (config.use_openai_api):
yield inner_oai
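
Note: inner_oai now estimates Usage from whitespace word counts; this is a test-only approximation, not a tokenizer-accurate count. A sketch of the arithmetic on a concrete request (the ReAct and ReWOO agents below apply the same heuristic; the model value here is arbitrary):

request = ChatRequest.from_string("hello there general kenobi", model="test-model")

prompt_tokens = sum(len(str(m.content).split()) for m in request.messages)  # 4
completion_tokens = prompt_tokens  # the echo reply repeats all four words
usage = Usage(prompt_tokens=prompt_tokens,
              completion_tokens=completion_tokens,
              total_tokens=prompt_tokens + completion_tokens)  # 8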
10 changes: 9 additions & 1 deletion src/nat/agent/react_agent/register.py
@@ -25,6 +25,7 @@
from nat.data_models.agent import AgentBaseConfig
from nat.data_models.api_server import ChatRequest
from nat.data_models.api_server import ChatResponse
from nat.data_models.api_server import Usage
from nat.data_models.component_ref import FunctionGroupRef
from nat.data_models.component_ref import FunctionRef
from nat.data_models.optimizable import OptimizableField
@@ -149,7 +150,14 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse:
# get and return the output from the state
state = ReActGraphState(**state)
output_message = state.messages[-1]
return ChatResponse.from_string(str(output_message.content))
content = str(output_message.content)

# Create usage statistics for the response
prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages)
completion_tokens = len(content.split()) if content else 0
total_tokens = prompt_tokens + completion_tokens
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
return ChatResponse.from_string(content, usage=usage)

except Exception as ex:
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
9 changes: 8 additions & 1 deletion src/nat/agent/rewoo_agent/register.py
@@ -26,6 +26,7 @@
from nat.data_models.agent import AgentBaseConfig
from nat.data_models.api_server import ChatRequest
from nat.data_models.api_server import ChatResponse
from nat.data_models.api_server import Usage
from nat.data_models.component_ref import FunctionGroupRef
from nat.data_models.component_ref import FunctionRef
from nat.utils.type_converter import GlobalTypeConverter
@@ -157,7 +158,13 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse:
# Ensure output_message is a string
if isinstance(output_message, list | dict):
output_message = str(output_message)
return ChatResponse.from_string(output_message)

# Create usage statistics for the response
prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages)
completion_tokens = len(output_message.split()) if output_message else 0
total_tokens = prompt_tokens + completion_tokens
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
return ChatResponse.from_string(output_message, usage=usage)

except Exception as ex:
logger.exception("ReWOO Agent failed with exception: %s", ex)
119 changes: 62 additions & 57 deletions src/nat/data_models/api_server.py
@@ -28,6 +28,7 @@
from pydantic import conlist
from pydantic import field_serializer
from pydantic import field_validator
from pydantic import model_serializer
from pydantic_core.core_schema import ValidationInfo

from nat.data_models.interactive import HumanPrompt
@@ -36,6 +37,11 @@
FINISH_REASONS = frozenset({'stop', 'length', 'tool_calls', 'content_filter', 'function_call'})


class UserMessageContentRoleType(str, Enum):
USER = "user"
ASSISTANT = "assistant"


class Request(BaseModel):
"""
Request is a data model that represents HTTP request attributes.
@@ -108,7 +114,7 @@ class Security(BaseModel):

class Message(BaseModel):
content: str | list[UserContent]
role: str
role: UserMessageContentRoleType


class ChatRequest(BaseModel):
@@ -164,7 +170,7 @@ def from_string(data: str,
max_tokens: int | None = None,
top_p: float | None = None) -> "ChatRequest":

return ChatRequest(messages=[Message(content=data, role="user")],
return ChatRequest(messages=[Message(content=data, role=UserMessageContentRoleType.USER)],
model=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -178,7 +184,7 @@ def from_content(content: list[UserContent],
max_tokens: int | None = None,
top_p: float | None = None) -> "ChatRequest":

return ChatRequest(messages=[Message(content=content, role="user")],
return ChatRequest(messages=[Message(content=content, role=UserMessageContentRoleType.USER)],
model=model,
temperature=temperature,
max_tokens=max_tokens,
Expand All @@ -187,29 +193,40 @@ def from_content(content: list[UserContent],

class ChoiceMessage(BaseModel):
content: str | None = None
role: str | None = None
role: UserMessageContentRoleType | None = None


class ChoiceDelta(BaseModel):
"""Delta object for streaming responses (OpenAI-compatible)"""
content: str | None = None
role: str | None = None
role: UserMessageContentRoleType | None = None


class Choice(BaseModel):
class ChoiceBase(BaseModel):
"""Base choice model with common fields for both streaming and non-streaming responses"""
model_config = ConfigDict(extra="allow")

message: ChoiceMessage | None = None
delta: ChoiceDelta | None = None
finish_reason: typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None = None
index: int
# logprobs: ChoiceLogprobs | None = None


class ChatResponseChoice(ChoiceBase):
"""Choice model for non-streaming responses - contains message field"""
message: ChoiceMessage


class ChatResponseChunkChoice(ChoiceBase):
"""Choice model for streaming responses - contains delta field"""
delta: ChoiceDelta


# Backward compatibility alias
Choice = ChatResponseChoice


class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens: int | None = None
completion_tokens: int | None = None
total_tokens: int | None = None
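
Note: the former Choice model is split so that non-streaming choices must carry a message and streaming choices must carry a delta, while Choice stays as an alias for the non-streaming variant; Usage fields become optional so unknown counts are representable. A minimal sketch, assuming the models behave exactly as declared here:

full = ChatResponseChoice(index=0,
                          message=ChoiceMessage(content="hi",
                                                role=UserMessageContentRoleType.ASSISTANT),
                          finish_reason="stop")
streamed = ChatResponseChunkChoice(index=0, delta=ChoiceDelta(content="h"))

assert Choice is ChatResponseChoice  # backward-compatibility alias
unknown = Usage()                    # all three counters default to None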


class ResponseSerializable(abc.ABC):
@@ -245,10 +262,10 @@ class ChatResponse(ResponseBaseModelOutput):
model_config = ConfigDict(extra="allow")
id: str
object: str = "chat.completion"
model: str = ""
model: str = "unknown-model"
created: datetime.datetime
choices: list[Choice]
usage: Usage | None = None
choices: list[ChatResponseChoice]
usage: Usage # Required for non-streaming responses per OpenAI spec
system_fingerprint: str | None = None
service_tier: typing.Literal["scale", "default"] | None = None

@@ -264,22 +281,27 @@ def from_string(data: str,
object_: str | None = None,
model: str | None = None,
created: datetime.datetime | None = None,
usage: Usage | None = None) -> "ChatResponse":
usage: Usage) -> "ChatResponse":

if id_ is None:
id_ = str(uuid.uuid4())
if object_ is None:
object_ = "chat.completion"
if model is None:
model = ""
model = "unknown-model"
if created is None:
created = datetime.datetime.now(datetime.UTC)

return ChatResponse(id=id_,
object=object_,
model=model,
created=created,
choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")],
choices=[
ChatResponseChoice(index=0,
message=ChoiceMessage(content=data,
role=UserMessageContentRoleType.ASSISTANT),
finish_reason="stop")
],
usage=usage)
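
Note: from_string now requires the usage argument (per the OpenAI-spec comment above) and stamps the assistant role on the returned message. A hedged call sketch:

resp = ChatResponse.from_string("hello",
                                usage=Usage(prompt_tokens=3,
                                            completion_tokens=1,
                                            total_tokens=4))
assert resp.model == "unknown-model"  # new default when no model is given
assert resp.choices[0].message.role is UserMessageContentRoleType.ASSISTANT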


@@ -293,9 +315,9 @@ class ChatResponseChunk(ResponseBaseModelOutput):
model_config = ConfigDict(extra="allow")

id: str
choices: list[Choice]
choices: list[ChatResponseChunkChoice]
created: datetime.datetime
model: str = ""
model: str = "unknown-model"
object: str = "chat.completion.chunk"
system_fingerprint: str | None = None
service_tier: typing.Literal["scale", "default"] | None = None
@@ -319,12 +341,18 @@ def from_string(data: str,
if created is None:
created = datetime.datetime.now(datetime.UTC)
if model is None:
model = ""
model = "unknown-model"
if object_ is None:
object_ = "chat.completion.chunk"

return ChatResponseChunk(id=id_,
choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")],
choices=[
ChatResponseChunkChoice(index=0,
delta=ChoiceDelta(
content=data,
role=UserMessageContentRoleType.ASSISTANT),
finish_reason="stop")
],
created=created,
model=model,
object=object_)
@@ -335,7 +363,7 @@ def create_streaming_chunk(content: str,
id_: str | None = None,
created: datetime.datetime | None = None,
model: str | None = None,
role: str | None = None,
role: UserMessageContentRoleType | None = None,
finish_reason: str | None = None,
usage: Usage | None = None,
system_fingerprint: str | None = None) -> "ChatResponseChunk":
@@ -345,15 +373,22 @@
if created is None:
created = datetime.datetime.now(datetime.UTC)
if model is None:
model = ""
model = "unknown-model"

delta = ChoiceDelta(content=content, role=role) if content is not None or role is not None else ChoiceDelta()

final_finish_reason = finish_reason if finish_reason in FINISH_REASONS else None

return ChatResponseChunk(
id=id_,
choices=[Choice(index=0, message=None, delta=delta, finish_reason=final_finish_reason)],
choices=[
ChatResponseChunkChoice(
index=0,
delta=delta,
finish_reason=typing.cast(
typing.Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] | None,
final_finish_reason))
],
created=created,
model=model,
object="chat.completion.chunk",
@@ -398,11 +433,6 @@ class GenerateResponse(BaseModel):
value: str | None = "default"


class UserMessageContentRoleType(str, Enum):
USER = "user"
ASSISTANT = "assistant"


class WebSocketMessageType(str, Enum):
"""
WebSocketMessageType is an Enum that represents WebSocket Message types.
@@ -622,7 +652,7 @@ def _nat_chat_request_to_string(data: ChatRequest) -> str:


def _string_to_nat_chat_request(data: str) -> ChatRequest:
return ChatRequest.from_string(data, model="")
return ChatRequest.from_string(data, model="unknown-model")


GlobalTypeConverter.register_converter(_string_to_nat_chat_request)
@@ -654,22 +684,12 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse:
GlobalTypeConverter.register_converter(_string_to_nat_chat_response)


def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk:
# Preserve original message structure for backward compatibility
return ChatResponseChunk(id=data.id, choices=data.choices, created=data.created, model=data.model)


GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)


# ======== ChatResponseChunk Converters ========
def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str:
if data.choices and len(data.choices) > 0:
choice = data.choices[0]
if choice.delta and choice.delta.content:
return choice.delta.content
if choice.message and choice.message.content:
return choice.message.content
return ""


@@ -685,21 +705,6 @@ def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk:

GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)


# ======== AINodeMessageChunk Converters ========
def _ai_message_chunk_to_nat_chat_response_chunk(data) -> ChatResponseChunk:
'''Converts LangChain/LangGraph AINodeMessageChunk to ChatResponseChunk'''
content = ""
if hasattr(data, 'content') and data.content is not None:
content = str(data.content)
elif hasattr(data, 'text') and data.text is not None:
content = str(data.text)
elif hasattr(data, 'message') and data.message is not None:
content = str(data.message)

return ChatResponseChunk.create_streaming_chunk(content=content, role="assistant", finish_reason=None)


# Compatibility aliases with previous releases
AIQChatRequest = ChatRequest
AIQChoiceMessage = ChoiceMessage