Skip to content

Commit 01644fa

Browse files
refactor: clean up message extraction logic
Signed-off-by: Patrick Chin <[email protected]>
1 parent 9051e54 commit 01644fa

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

packages/nvidia_nat_weave/src/nat/plugins/weave/weave_exporter.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717
from collections.abc import Generator
1818
from contextlib import contextmanager
19+
from typing import Any
1920

2021
from nat.data_models.intermediate_step import IntermediateStep
2122
from nat.data_models.span import Span
@@ -186,6 +187,41 @@ def _create_weave_call(self, step: IntermediateStep, span: Span) -> Call:
186187

187188
return call
188189

190+
def _extract_output_message(self, output_data: Any, outputs: dict[str, Any]) -> None:
191+
"""
192+
Extract message content from various response formats and add to outputs dictionary.
193+
194+
Args:
195+
output_data: The raw output data from the response
196+
outputs: Dictionary to populate with extracted message content
197+
"""
198+
# Handle direct "choices" attribute (non-streaming: output.choices[0].message.content)
199+
choices = getattr(output_data, 'choices', None)
200+
if choices:
201+
outputs["output_message"] = choices[0].message.content
202+
return
203+
204+
# Handle list-based output (streaming or websocket) – content may be in the following formats:
205+
# output[0].choices[0].message.content
206+
# output[0].choices[0].delta.content
207+
# output[0].value
208+
if not isinstance(output_data, list) or not output_data:
209+
return
210+
211+
choices = getattr(output_data[0], 'choices', None)
212+
if choices:
213+
message = getattr(choices[0], 'message', None)
214+
delta = getattr(choices[0], 'delta', None)
215+
216+
if message:
217+
outputs["output_message"] = getattr(message, 'content', None)
218+
elif delta:
219+
outputs["output_preview"] = getattr(delta, 'content', None)
220+
else:
221+
value = getattr(output_data[0], 'value', None)
222+
if value:
223+
outputs["output_preview"] = value
224+
189225
def _finish_weave_call(self, step: IntermediateStep) -> None:
190226
"""
191227
Finish a previously created Weave call.
@@ -206,33 +242,7 @@ def _finish_weave_call(self, step: IntermediateStep) -> None:
206242
try:
207243
# Add the output to the Weave call
208244
outputs["output"] = step.payload.data.output
209-
210-
# Extract message content based on response format
211-
# Non-streaming: output.choices[0].message.content
212-
choices = getattr(step.payload.data.output, 'choices', None)
213-
if choices:
214-
outputs["output_message"] = choices[0].message.content
215-
# List format (websocket/streaming):
216-
# output[0].choices[0].message.content or
217-
# output[0].choices[0].delta.content
218-
elif isinstance(step.payload.data.output, list) and len(step.payload.data.output) > 0:
219-
first_item = step.payload.data.output[0]
220-
choices = getattr(first_item, 'choices', None)
221-
if choices and len(choices) > 0:
222-
# Try websocket format: choices[0].message.content
223-
message = getattr(choices[0], 'message', None)
224-
if message:
225-
outputs["output_message"] = getattr(message, 'content', None)
226-
# Try streaming format: choices[0].delta.content
227-
else:
228-
delta = getattr(choices[0], 'delta', None)
229-
if delta:
230-
outputs["output_preview"] = getattr(delta, 'content', None)
231-
# Generate endpoint: output[0].value
232-
else:
233-
value = getattr(first_item, 'value', None)
234-
if value:
235-
outputs["output_preview"] = value
245+
self._extract_output_message(step.payload.data.output, outputs)
236246
except Exception:
237247
# If serialization fails, use string representation
238248
outputs["output"] = str(step.payload.data.output)

0 commit comments

Comments
 (0)