Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/vllm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# - `special_sanity`: a suite of quick sanity tests
# - `special_standalone`: a set of test that are designed to run in dedicated environments

# Accelerators for tests
# Accelerators for tests
# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`.
# - Test scripts with the `on_cpu.py` name suffix are run on CPU resources in a Linux environment.

Expand Down Expand Up @@ -128,4 +128,8 @@ jobs:
run: |
pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2
pytest -svvv tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py
- name: Running LangGraph multi-turn rollout tests on 8 L20 GPUs
run: |
pip3 install -e .[langgraph]
pytest -svvv tests/workers/rollout/rollout_vllm/test_vllm_langgraph_chat_scheduler.py
# Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests
273 changes: 273 additions & 0 deletions examples/grpo_trainer/config/reflection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 JetBrains s.r.o. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LangGraph implementation with ReAct agent and reflection step.

This module creates a graph where:
1. A ReAct agent performs initial reasoning and action
2. A deterministic validation message is appended asking the agent to validate its work
3. The agent can go through multiple rounds of self-reflection
"""

import json
from typing import List, Optional, TypedDict

import requests
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.tools import BaseTool, tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import create_react_agent


class GraphState(TypedDict):
    """State for the reflection graph."""

    # Full conversation history exchanged with the agent.
    messages: List[BaseMessage]
    # Number of validation rounds completed so far; seeded to 0 in assign_initial_state.
    current_round: int
    # Upper bound on rounds; set once from the max_rounds argument in assign_initial_state.
    max_rounds: int
    # True right after the agent replies, signalling the router to insert a validation step.
    needs_validation: bool


def create_sandboxfusion_tool(sandboxfusion_api: str, request_timeout: float = 30.0) -> BaseTool:
    """Create a code-execution tool backed by a SandboxFusion API endpoint.

    Args:
        sandboxfusion_api: Base URL of the SandboxFusion service
            (e.g. ``http://sandbox-fusion:8080``).
        request_timeout: Seconds to wait for the sandbox HTTP call. A timeout is
            essential here: the sandbox runs arbitrary model-written code, and
            without one a hung request would block the rollout forever.

    Returns:
        A LangChain tool that executes Python code remotely and returns a
        formatted report of the run.
    """

    def _section(title: str, body: str) -> List[str]:
        # Render one titled report section (STDOUT / STDERR).
        return ["-" * 30, f"{title}:", "-" * 30, body]

    @tool
    def code_execution_tool(code: str) -> str:
        """
        Execute Python code in a sandboxed environment.
        Returns the stdout and stderr of the code execution.
        Please call print() for any variables you want to see.
        If you don't call print(), the output will not be shown.
        """
        try:
            # Make request to SandboxFusion API
            response = requests.post(
                f"{sandboxfusion_api}/run_code",
                json={"code": code, "language": "python", "files": {}},
                timeout=request_timeout,
            )
            result = response.json()

            # Format the output in a pretty way
            output_parts = ["=" * 50, "CODE EXECUTION RESULT", "=" * 50]

            # Add status
            output_parts.append(f"Status: {result.get('status', 'Unknown')}")
            if result.get("message"):
                output_parts.append(f"Message: {result['message']}")

            # Add run result details
            run_result = result.get("run_result", {})
            if run_result:
                output_parts.append(f"Execution Status: {run_result.get('status', 'Unknown')}")
                output_parts.append(f"Execution Time: {run_result.get('execution_time', 'N/A')}s")
                output_parts.append(f"Return Code: {run_result.get('return_code', 'N/A')}")

            # Add stdout if present; nudge the model to print() when it is empty
            stdout = run_result.get("stdout", "").rstrip() or "<No stdout. Use print() to see the output>"
            output_parts.extend(_section("STDOUT", stdout))

            # Add stderr if present
            stderr = run_result.get("stderr", "").rstrip() or "<No stderr>"
            output_parts.extend(_section("STDERR", stderr))

            output_parts.append("=" * 50)

            return "\n".join(output_parts)

        except requests.exceptions.RequestException as e:
            return f"Error connecting to SandboxFusion API: {e}"
        except json.JSONDecodeError as e:
            return f"Error parsing SandboxFusion response: {e}"
        except Exception as e:
            return f"Unexpected error: {e}"

    return code_execution_tool


def assign_initial_state(state: GraphState, max_rounds: int) -> GraphState:
    """Seed the graph state with round counters before the first agent turn."""
    seeded = dict(state)
    seeded.update(current_round=0, max_rounds=max_rounds, needs_validation=False)
    return seeded


def react_agent_node(state: GraphState, react_agent) -> GraphState:
    """Run one ReAct agent turn and flag the reply for validation."""

    # Hand the full transcript to the agent and adopt its updated transcript.
    agent_output = react_agent.invoke({"messages": state["messages"]})

    next_state = dict(state)
    next_state["messages"] = agent_output["messages"]
    next_state["needs_validation"] = True
    return next_state


def validation_node(state: GraphState) -> GraphState:
    """Append the deterministic self-review prompt and advance the round counter."""

    review_prompt = HumanMessage(
        content="""
Please carefully review your previous response and reasoning. Check the following:

1. Is your reasoning logically sound and well-connected?
2. Are all your calculations correct?
3. Did you use the tools appropriately?
4. Is your final answer accurate and complete?
5. Are there any errors or missing steps?

If you find any issues, please provide a corrected response. If everything looks correct, please confirm your answer is final.
"""
    )

    return {
        **state,
        "messages": [*state["messages"], review_prompt],
        "current_round": state["current_round"] + 1,
        "needs_validation": False,
    }


def create_reflection_agent(model: BaseLanguageModel, sandboxfusion_api: str, max_rounds: int = 2) -> StateGraph:
    """
    Create a reflection agent that uses deterministic validation messages.

    Graph shape: assign_state -> agent -> (validation -> agent)*, ending after
    the agent's reply once `max_rounds - 1` validation rounds have run.

    Args:
        model: The language model to use for the ReAct agent
        sandboxfusion_api: SandboxFusion API endpoint passed to the code-execution tool
        max_rounds: Maximum number of agent rounds; each round after the first
            is preceded by a validation prompt

    Returns:
        The compiled graph for the reflection agent (the result of
        ``workflow.compile()``)
    """

    # Create tools with SandboxFusion API
    tools = [create_sandboxfusion_tool(sandboxfusion_api)]

    # Create the ReAct agent that powers the "agent" node
    react_agent = create_react_agent(model, tools)

    # Create the graph
    workflow = StateGraph(GraphState)

    # Nodes: seed counters, run the agent, append the self-review prompt
    workflow.add_node("assign_state", lambda state: assign_initial_state(state, max_rounds))
    workflow.add_node("agent", lambda state: react_agent_node(state, react_agent))
    workflow.add_node("validation", validation_node)

    workflow.set_entry_point("assign_state")
    workflow.add_edge("assign_state", "agent")

    def should_continue_from_agent(state: GraphState) -> str:
        """Route to validation until the final round, then end."""
        # current_round counts completed validations; the agent turn that
        # follows validation round (max_rounds - 1) ends the graph.
        if state["needs_validation"] and state["current_round"] < state["max_rounds"] - 1:
            return "validate"
        return "end"

    workflow.add_conditional_edges("agent", should_continue_from_agent, {"validate": "validation", "end": END})

    # Validation always returns to the agent, so a plain edge is equivalent to
    # (and simpler than) an always-true conditional edge.
    workflow.add_edge("validation", "agent")

    return workflow.compile()


def run_reflection_agent(question: str, llm: BaseLanguageModel, max_rounds: int = 3, sandboxfusion_api: Optional[str] = None) -> dict:
    """
    Run the reflection agent with a given question.

    Args:
        question: The question to ask the agent
        llm: The language model to use
        max_rounds: Maximum number of reflection rounds
        sandboxfusion_api: SandboxFusion API endpoint or identifier
            (NOTE(review): if left as None the code-execution tool will fail
            when invoked — confirm callers always supply it)

    Returns:
        The final graph state plus a "final_answer" key holding the content of
        the last AIMessage (empty string if no AI message was produced)
    """

    # Create the graph with SandboxFusion API
    app = create_reflection_agent(llm, sandboxfusion_api, max_rounds)

    # Initial state with only messages; assign_state fills in the counters
    initial_state = {"messages": [HumanMessage(content=question)]}

    # recursion_limit bounds graph steps so a mis-routed graph cannot spin forever
    final_state = app.invoke(initial_state, config={"recursion_limit": 50})

    # The last AIMessage in the transcript is the agent's final answer
    final_answer = next((message.content for message in reversed(final_state["messages"]) if isinstance(message, AIMessage)), "")

    return {**final_state, "final_answer": final_answer}


# Example usage
if __name__ == "__main__":
    # Sample AIME-style problem used to exercise the agent end to end
    question = r"""
Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.

Find the largest possible real part of [(75+117i)z+\frac{96+144i}{z}]where $z$ is a complex number with $|z|=4$.

Remember to put your answer on its own line after "Answer:"."""

    print(f"Question: {question}")
    print("-" * 50)

    # Model under test
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # SandboxFusion endpoint (assumed to be running locally)
    sandbox_api = "http://localhost:8080"

    # Drive the reflection agent against the sandbox
    result = run_reflection_agent(question, llm=llm, max_rounds=2, sandboxfusion_api=sandbox_api)

    print(f"Final Answer: {result['final_answer']}")
    print(f"Rounds Completed: {result['current_round']}")

    # Dump the conversation turn by turn, flagging tool-call messages
    print("\nConversation History:")
    for idx, msg in enumerate(result["messages"], start=1):
        kind = type(msg).__name__
        calls = getattr(msg, "tool_calls", None)
        if calls:
            print(f"{idx}. {kind} with Tool Calls: {calls}")
        else:
            print(f"{idx}. {kind}: {msg.content}")
11 changes: 11 additions & 0 deletions examples/grpo_trainer/config/retool_reflection.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
graph:
_target_: examples.grpo_trainer.config.reflection.create_reflection_agent
sandboxfusion_api: http://sandbox-fusion:8080
chat_template_kwargs:
# See https://huggingface.co/Qwen/Qwen3-32B/discussions/30 for more info
chat_template: "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = 
content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n\n {{- '<|im_start|>' + message.role }}\n {% generation %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- content }}\n {%- endif %}\n {%- else %}\n {{- content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>' }}\n {% endgeneration %}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
tools:
- _target_: langchain_core.utils.function_calling.convert_to_openai_function
function:
_target_: examples.grpo_trainer.config.reflection.create_sandboxfusion_tool
sandboxfusion_api: dummy
Loading