diff --git a/sentry_sdk/integrations/langgraph.py b/sentry_sdk/integrations/langgraph.py
index 5bb0e0fd08..5464b2daef 100644
--- a/sentry_sdk/integrations/langgraph.py
+++ b/sentry_sdk/integrations/langgraph.py
@@ -62,7 +62,13 @@ def _normalize_langgraph_message(message):
 
     parsed = {"role": getattr(message, "type", None), "content": message.content}
 
-    for attr in ["name", "tool_calls", "function_call", "tool_call_id"]:
+    for attr in [
+        "name",
+        "tool_calls",
+        "function_call",
+        "tool_call_id",
+        "response_metadata",
+    ]:
         if hasattr(message, attr):
             value = getattr(message, attr)
             if value is not None:
@@ -311,14 +317,51 @@ def _extract_tool_calls(messages):
     return tool_calls if tool_calls else None
 
 
+def _set_usage_data(span, messages):
+    # type: (sentry_sdk.tracing.Span, Any) -> None
+    input_tokens = 0
+    output_tokens = 0
+    total_tokens = 0
+
+    for message in messages:
+        response_metadata = message.get("response_metadata")
+        if response_metadata is None:
+            continue
+
+        token_usage = response_metadata.get("token_usage")
+        if not token_usage:
+            continue
+
+        input_tokens += int(token_usage.get("prompt_tokens", 0))
+        output_tokens += int(token_usage.get("completion_tokens", 0))
+        total_tokens += int(token_usage.get("total_tokens", 0))
+
+    if input_tokens > 0:
+        span.set_data(SPANDATA.GEN_AI_USAGE_INPUT_TOKENS, input_tokens)
+
+    if output_tokens > 0:
+        span.set_data(SPANDATA.GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens)
+
+    if total_tokens > 0:
+        span.set_data(
+            SPANDATA.GEN_AI_USAGE_TOTAL_TOKENS,
+            total_tokens,
+        )
+
+
 def _set_response_attributes(span, input_messages, result, integration):
     # type: (Any, Optional[List[Any]], Any, LanggraphIntegration) -> None
-    if not (should_send_default_pii() and integration.include_prompts):
-        return
-
     parsed_response_messages = _parse_langgraph_messages(result)
     new_messages = _get_new_messages(input_messages, parsed_response_messages)
 
+    if new_messages is None:
+        return
+
+    _set_usage_data(span, new_messages)
+
+    if not (should_send_default_pii() and integration.include_prompts):
+        return
+
     llm_response_text = _extract_llm_response_text(new_messages)
     if llm_response_text:
         set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, llm_response_text)
diff --git a/tests/integrations/langgraph/test_langgraph.py b/tests/integrations/langgraph/test_langgraph.py
index df574dd2c3..1f6c27cd62 100644
--- a/tests/integrations/langgraph/test_langgraph.py
+++ b/tests/integrations/langgraph/test_langgraph.py
@@ -96,6 +96,7 @@ def __init__(
         function_call=None,
         role=None,
         type=None,
+        response_metadata=None,
     ):
         self.content = content
         self.name = name
@@ -108,6 +109,7 @@ def __init__(
             self.type = name
         else:
             self.type = type
+        self.response_metadata = response_metadata
 
 
 class MockPregelInstance:
@@ -509,6 +511,326 @@ def original_invoke(self, *args, **kwargs):
     assert SPANDATA.GEN_AI_AGENT_NAME not in invoke_span.get("data", {})
 
 
+def test_pregel_invoke_span_includes_usage_data(sentry_init, capture_events):
+    """
+    Test that invoke_agent spans include token usage data aggregated from the
+    response messages' response_metadata.
+    """
+    sentry_init(
+        integrations=[LanggraphIntegration()],
+        traces_sample_rate=1.0,
+    )
+    events = capture_events()
+
+    test_state = {
+        "messages": [
+            MockMessage("Hello, can you help me?", name="user"),
+            MockMessage("Of course! How can I assist you?", name="assistant"),
+        ]
+    }
+
+    pregel = MockPregelInstance("test_graph")
+
+    expected_assistant_response = "I'll help you with that task!"
+    expected_tool_calls = [
+        {
+            "id": "call_test_123",
+            "type": "function",
+            "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+        }
+    ]
+
+    def original_invoke(self, *args, **kwargs):
+        input_messages = args[0].get("messages", [])
+        new_messages = input_messages + [
+            MockMessage(
+                content=expected_assistant_response,
+                name="assistant",
+                tool_calls=expected_tool_calls,
+                response_metadata={
+                    "token_usage": {
+                        "total_tokens": 30,
+                        "prompt_tokens": 10,
+                        "completion_tokens": 20,
+                    },
+                    "model_name": "gpt-4.1-2025-04-14",
+                },
+            )
+        ]
+        return {"messages": new_messages}
+
+    with start_transaction():
+        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+        result = wrapped_invoke(pregel, test_state)
+
+    assert result is not None
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+
+    invoke_agent_span = invoke_spans[0]
+
+    # Verify invoke_agent span has usage data
+    assert invoke_agent_span["description"] == "invoke_agent test_graph"
+    assert "gen_ai.usage.input_tokens" in invoke_agent_span["data"]
+    assert "gen_ai.usage.output_tokens" in invoke_agent_span["data"]
+    assert "gen_ai.usage.total_tokens" in invoke_agent_span["data"]
+
+    # The usage should match the token_usage values from the message's response_metadata
+    assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 10
+    assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+    assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+def test_pregel_ainvoke_span_includes_usage_data(sentry_init, capture_events):
+    """
+    Test that invoke_agent spans created via ainvoke include token usage data
+    aggregated from the response messages' response_metadata.
+    """
+    sentry_init(
+        integrations=[LanggraphIntegration()],
+        traces_sample_rate=1.0,
+    )
+    events = capture_events()
+
+    test_state = {
+        "messages": [
+            MockMessage("Hello, can you help me?", name="user"),
+            MockMessage("Of course! How can I assist you?", name="assistant"),
+        ]
+    }
+
+    pregel = MockPregelInstance("test_graph")
+
+    expected_assistant_response = "I'll help you with that task!"
+    expected_tool_calls = [
+        {
+            "id": "call_test_123",
+            "type": "function",
+            "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+        }
+    ]
+
+    async def original_ainvoke(self, *args, **kwargs):
+        input_messages = args[0].get("messages", [])
+        new_messages = input_messages + [
+            MockMessage(
+                content=expected_assistant_response,
+                name="assistant",
+                tool_calls=expected_tool_calls,
+                response_metadata={
+                    "token_usage": {
+                        "total_tokens": 30,
+                        "prompt_tokens": 10,
+                        "completion_tokens": 20,
+                    },
+                    "model_name": "gpt-4.1-2025-04-14",
+                },
+            )
+        ]
+        return {"messages": new_messages}
+
+    async def run_test():
+        with start_transaction():
+            wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+            result = await wrapped_ainvoke(pregel, test_state)
+            return result
+
+    result = asyncio.run(run_test())
+    assert result is not None
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+
+    invoke_agent_span = invoke_spans[0]
+
+    # Verify invoke_agent span has usage data
+    assert invoke_agent_span["description"] == "invoke_agent test_graph"
+    assert "gen_ai.usage.input_tokens" in invoke_agent_span["data"]
+    assert "gen_ai.usage.output_tokens" in invoke_agent_span["data"]
+    assert "gen_ai.usage.total_tokens" in invoke_agent_span["data"]
+
+    # The usage should match the token_usage values from the message's response_metadata
+    assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 10
+    assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+    assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 30
+
+
+def test_pregel_invoke_multiple_llm_calls_aggregate_usage(sentry_init, capture_events):
+    """
+    Test that invoke_agent spans show aggregated usage across multiple LLM calls
+    (e.g., when tools are used and multiple API calls are made).
+    """
+    sentry_init(
+        integrations=[LanggraphIntegration()],
+        traces_sample_rate=1.0,
+    )
+    events = capture_events()
+
+    test_state = {
+        "messages": [
+            MockMessage("Hello, can you help me?", name="user"),
+            MockMessage("Of course! How can I assist you?", name="assistant"),
+        ]
+    }
+
+    pregel = MockPregelInstance("test_graph")
+
+    expected_assistant_response = "I'll help you with that task!"
+    expected_tool_calls = [
+        {
+            "id": "call_test_123",
+            "type": "function",
+            "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+        }
+    ]
+
+    def original_invoke(self, *args, **kwargs):
+        input_messages = args[0].get("messages", [])
+        new_messages = input_messages + [
+            MockMessage(
+                content=expected_assistant_response,
+                name="assistant",
+                tool_calls=expected_tool_calls,
+                response_metadata={
+                    "token_usage": {
+                        "total_tokens": 15,
+                        "prompt_tokens": 10,
+                        "completion_tokens": 5,
+                    },
+                },
+            ),
+            MockMessage(
+                content=expected_assistant_response,
+                name="assistant",
+                tool_calls=expected_tool_calls,
+                response_metadata={
+                    "token_usage": {
+                        "total_tokens": 35,
+                        "prompt_tokens": 20,
+                        "completion_tokens": 15,
+                    },
+                },
+            ),
+        ]
+        return {"messages": new_messages}
+
+    with start_transaction():
+        wrapped_invoke = _wrap_pregel_invoke(original_invoke)
+        result = wrapped_invoke(pregel, test_state)
+
+    assert result is not None
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+    invoke_agent_span = invoke_spans[0]
+
+    # Verify invoke_agent span has aggregated usage from both API calls
+    # Total: 10 + 20 = 30 input tokens, 5 + 15 = 20 output tokens, 15 + 35 = 50 total
+    assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 30
+    assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+    assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 50
+
+
+def test_pregel_ainvoke_multiple_llm_calls_aggregate_usage(sentry_init, capture_events):
+    """
+    Test that invoke_agent spans show aggregated usage across multiple LLM calls
+    (e.g., when tools are used and multiple API calls are made).
+    """
+    sentry_init(
+        integrations=[LanggraphIntegration()],
+        traces_sample_rate=1.0,
+    )
+    events = capture_events()
+
+    test_state = {
+        "messages": [
+            MockMessage("Hello, can you help me?", name="user"),
+            MockMessage("Of course! How can I assist you?", name="assistant"),
+        ]
+    }
+
+    pregel = MockPregelInstance("test_graph")
+
+    expected_assistant_response = "I'll help you with that task!"
+    expected_tool_calls = [
+        {
+            "id": "call_test_123",
+            "type": "function",
+            "function": {"name": "search_tool", "arguments": '{"query": "help"}'},
+        }
+    ]
+
+    async def original_ainvoke(self, *args, **kwargs):
+        input_messages = args[0].get("messages", [])
+        new_messages = input_messages + [
+            MockMessage(
+                content=expected_assistant_response,
+                name="assistant",
+                tool_calls=expected_tool_calls,
+                response_metadata={
+                    "token_usage": {
+                        "total_tokens": 15,
+                        "prompt_tokens": 10,
+                        "completion_tokens": 5,
+                    },
+                },
+            ),
+            MockMessage(
+                content=expected_assistant_response,
+                name="assistant",
+                tool_calls=expected_tool_calls,
+                response_metadata={
+                    "token_usage": {
+                        "total_tokens": 35,
+                        "prompt_tokens": 20,
+                        "completion_tokens": 15,
+                    },
+                },
+            ),
+        ]
+        return {"messages": new_messages}
+
+    async def run_test():
+        with start_transaction():
+            wrapped_ainvoke = _wrap_pregel_ainvoke(original_ainvoke)
+            result = await wrapped_ainvoke(pregel, test_state)
+            return result
+
+    result = asyncio.run(run_test())
+    assert result is not None
+
+    tx = events[0]
+    assert tx["type"] == "transaction"
+
+    invoke_spans = [
+        span for span in tx["spans"] if span["op"] == OP.GEN_AI_INVOKE_AGENT
+    ]
+    assert len(invoke_spans) == 1
+    invoke_agent_span = invoke_spans[0]
+
+    # Verify invoke_agent span has aggregated usage from both API calls
+    # Total: 10 + 20 = 30 input tokens, 5 + 15 = 20 output tokens, 15 + 35 = 50 total
+    assert invoke_agent_span["data"]["gen_ai.usage.input_tokens"] == 30
+    assert invoke_agent_span["data"]["gen_ai.usage.output_tokens"] == 20
+    assert invoke_agent_span["data"]["gen_ai.usage.total_tokens"] == 50
+
+
 def test_complex_message_parsing():
     """Test message parsing with complex message structures."""
     messages = [