Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
import logging
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

from nat.data_models.intermediate_step import IntermediateStep
from nat.data_models.span import Span
from nat.observability.exporter.base_exporter import IsolatedAttribute
from nat.observability.exporter.span_exporter import SpanExporter
from nat.utils.log_utils import LogFilter
from nat.utils.string_utils import truncate_string
from nat.utils.type_utils import override
from weave.trace.context import weave_client_context
from weave.trace.context.call_context import get_current_call
Expand Down Expand Up @@ -152,6 +154,7 @@ def _create_weave_call(self, step: IntermediateStep, span: Span) -> Call:
try:
# Add the input to the Weave call
inputs["input"] = step.payload.data.input
self._extract_input_message(step.payload.data.input, inputs)
except Exception:
# If serialization fails, use string representation
inputs["input"] = str(step.payload.data.input)
Expand All @@ -176,6 +179,74 @@ def _create_weave_call(self, step: IntermediateStep, span: Span) -> Call:

return call

def _extract_input_message(self, input_data: Any, inputs: dict[str, Any]) -> None:
"""
Extract message content from input data and add to inputs dictionary.
Also handles websocket mode where message is located at messages[0].content[0].text.

Args:
input_data: The raw input data from the request
inputs: Dictionary to populate with extracted message content
"""
# Extract message content if input has messages attribute
messages = getattr(input_data, 'messages', [])
if messages:
content = messages[0].content
if isinstance(content, list) and content:
inputs["input_message"] = getattr(content[0], 'text', content[0])
else:
inputs["input_message"] = content

def _extract_output_message(self, output_data: Any, outputs: dict[str, Any]) -> None:
"""
Extract message content from various response formats and add a preview to the outputs dictionary.
Supported output formats for message content include:
- output.choices[0].message.content /chat endpoint
- output.value /generate endpoint
- output[0].choices[0].message.content chat WS schema
- output[0].choices[0].delta.content chat_stream WS schema, /chat/stream endpoint
- output[0].value generate & generate_stream WS schema, /generate/stream endpoint

Args:
output_data: The raw output data from the response
outputs: Dictionary to populate with extracted message content.
No data is added to the outputs dictionary if the output format is not supported.
"""
# Handle choices-keyed output object for /chat completion endpoint
choices = getattr(output_data, 'choices', None)
if choices:
outputs["output_message"] = truncate_string(choices[0].message.content)
return

# Handle value-keyed output object for union types common for /generate completion endpoint
value = getattr(output_data, 'value', None)
if value:
outputs["output_message"] = truncate_string(value)
return

# Handle list-based outputs (streaming or websocket)
if not isinstance(output_data, list) or not output_data:
return

choices = getattr(output_data[0], 'choices', None)
if choices:
# chat websocket schema
message = getattr(choices[0], 'message', None)
if message:
outputs["output_message"] = truncate_string(getattr(message, 'content', None))
return

# chat_stream websocket schema and /chat/stream completion endpoint
delta = getattr(choices[0], 'delta', None)
if delta:
outputs["output_preview"] = truncate_string(getattr(delta, 'content', None))
return

# generate & generate_stream websocket schema, and /generate/stream completion endpoint
value = getattr(output_data[0], 'value', None)
if value:
outputs["output_preview"] = truncate_string(str(value))

def _finish_weave_call(self, step: IntermediateStep) -> None:
"""
Finish a previously created Weave call.
Expand All @@ -196,6 +267,7 @@ def _finish_weave_call(self, step: IntermediateStep) -> None:
try:
# Add the output to the Weave call
outputs["output"] = step.payload.data.output
self._extract_output_message(step.payload.data.output, outputs)
except Exception:
# If serialization fails, use string representation
outputs["output"] = str(step.payload.data.output)
Expand Down
14 changes: 11 additions & 3 deletions src/nat/runtime/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ async def result(self, to_type: type | None = None):
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_START,
name=workflow_name,
metadata=start_metadata))
metadata=start_metadata,
data=StreamEventData(input=self._input_message)))

result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type) # type: ignore

Expand Down Expand Up @@ -249,9 +250,15 @@ async def result_stream(self, to_type: type | None = None):
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_START,
name=workflow_name,
metadata=start_metadata))
metadata=start_metadata,
data=StreamEventData(input=self._input_message)))

# Collect preview of streaming results for the WORKFLOW_END event
output_preview = []

async for m in self._entry_fn.astream(self._input_message, to_type=to_type): # type: ignore
if len(output_preview) < 50:
output_preview.append(m)
yield m

# Emit WORKFLOW_END
Expand All @@ -265,7 +272,8 @@ async def result_stream(self, to_type: type | None = None):
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_END,
name=workflow_name,
metadata=end_metadata))
metadata=end_metadata,
data=StreamEventData(output=output_preview)))
self._state = RunnerState.COMPLETED

# Close the intermediate stream
Expand Down
16 changes: 16 additions & 0 deletions src/nat/utils/string_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,19 @@ def convert_to_str(value: Any) -> str:
return str(value)
else:
raise ValueError(f"Unsupported type for conversion to string: {type(value)}")


def truncate_string(text: str | None, max_length: int = 100) -> str | None:
"""
Truncate a string to a maximum length, adding ellipsis if truncated.

Args:
text: The text to truncate (can be None)
max_length: Maximum allowed length (default: 100)

Returns:
The truncated text with ellipsis if needed, or None if input was None
"""
if not text or len(text) <= max_length:
return text
return text[:max_length - 3] + "..."
Loading