diff --git a/README.md b/README.md index d976c8cfd..a130158d1 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ The RAI framework aims to: - [Features](#features) - [Setup](#setup) - [Usage examples (demos)](#simulation-demos) +- [Debugging Assistant](#debugging-assistant) - [Developer resources](#developer-resources) ## Features @@ -66,6 +67,7 @@ The RAI framework aims to: - [x] Improved Human-Robot Interaction with voice and text. - [x] Additional tooling such as GroundingDino. - [x] Support for at least 3 different AI vendors. +- [x] Debugging assistant for ROS 2. - [ ] SDK for RAI developers. - [ ] UI for configuration to select features and tools relevant for your deployment. @@ -153,6 +155,10 @@ Once you know your way around RAI, try the following challenges, with the aid th - Implement additional tools and use them in your interaction. - Try a complex, multi-step task for your robot, such as going to several points to perform observations! +## Debugging Assistant + +Use the [debugging assistant](./docs/debugging_assistant.md) to inspect ROS 2 network state and troubleshoot issues. + ### Simulation demos Try RAI yourself with these demos: diff --git a/docs/debugging_assistant.md b/docs/debugging_assistant.md new file mode 100644 index 000000000..620dd62d2 --- /dev/null +++ b/docs/debugging_assistant.md @@ -0,0 +1,50 @@ +# ROS 2 Debugging Assistant + +The ROS 2 Debugging Assistant is an interactive tool that helps developers inspect and troubleshoot their ROS 2 systems using natural language. It provides a chat-like interface powered by Streamlit where you can ask questions about your ROS 2 setup and execute common debugging commands. 
+ +## Features + +- Interactive chat interface for debugging ROS 2 systems +- Real-time streaming of responses and tool executions +- Support for common ROS 2 debugging commands: + - `ros2 topic`: topic inspection and manipulation + - `ros2 service`: service inspection and calling + - `ros2 node`: node information + - `ros2 action`: action server details and goal sending + - `ros2 interface`: interface inspection + - `ros2 param`: checking and setting parameters + +## Running the Assistant + +1. Make sure you have RAI installed and configured according to the [setup instructions](../README.md#setup) + +2. Launch the debugging assistant: + +```sh +source setup_shell.sh +streamlit run src/rai/rai/tools/debugging_assistant.py +``` + +## Usage Examples + +Here are some example queries you can try: + +- "What topics are currently available?" +- "Show me the message type for /cmd_vel" +- "List all active nodes" +- "What services does the /robot_state_publisher node provide?" +- "Show me information about the /navigate_to_pose action" + +## How it Works + +The debugging assistant uses RAI's conversational agent capabilities combined with ROS 2 debugging tools. The key components are: + +1. **Streamlit Interface**: Provides the chat UI and displays tool execution results +2. **ROS 2 Tools**: Collection of debugging tools that wrap common ROS 2 CLI commands +3. 
**Streaming Callbacks**: Real-time updates of LLM responses and tool executions
+
+## Limitations
+
+- Commands are run without a shell and screened for forbidden shell characters, but not every available command is read-only (e.g. `ros2 param set`, `ros2 topic pub`, `ros2 service call` can change system state) — review generated commands before relying on them
+- Some complex debugging scenarios may require manual intervention
+- Performance depends on the chosen LLM vendor and model
diff --git a/src/rai/rai/agents/integrations/__init__.py b/src/rai/rai/agents/integrations/__init__.py
new file mode 100644
index 000000000..ef74fc891
--- /dev/null
+++ b/src/rai/rai/agents/integrations/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/src/rai/rai/agents/integrations/streamlit.py b/src/rai/rai/agents/integrations/streamlit.py
new file mode 100644
index 000000000..128d1556a
--- /dev/null
+++ b/src/rai/rai/agents/integrations/streamlit.py
@@ -0,0 +1,163 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import inspect +from typing import Any, Callable, Dict, TypeVar + +import streamlit as st +from langchain_core.callbacks.base import BaseCallbackHandler +from streamlit.delta_generator import DeltaGenerator +from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx + +# code inspired by (mostly copied, some changes were applied and might be updated in the future) +# https://github.com/shiv248/Streamlit-x-LangGraph-Cookbooks/tree/b8e623bdc9821fc1cf581607454dae1afc054df2/tool_calling_via_callback + + +# Define a function to create a callback handler for Streamlit that updates the UI dynamically +def get_streamlit_cb(parent_container: DeltaGenerator) -> BaseCallbackHandler: + """ + Creates a Streamlit callback handler that updates the provided Streamlit container with new tokens. + Args: + parent_container (DeltaGenerator): The Streamlit container where the text will be rendered. + Returns: + BaseCallbackHandler: An instance of a callback handler configured for Streamlit. + """ + + # Define a custom callback handler class for managing and displaying stream events in Streamlit + class StreamHandler(BaseCallbackHandler): + """ + Custom callback handler for Streamlit that updates a Streamlit container with new tokens. + """ + + def __init__( + self, container: st.delta_generator.DeltaGenerator, initial_text: str = "" + ): + """ + Initializes the StreamHandler with a Streamlit container and optional initial text. + Args: + container (st.delta_generator.DeltaGenerator): The Streamlit container where text will be rendered. + initial_text (str): Optional initial text to start with in the container. 
+ """ + self.container = container # The Streamlit container to update + self.thoughts_placeholder = ( + self.container.container() + ) # container to hold tool_call renders + self.tool_output_placeholder = None # placeholder for the output of the tool call to be in the expander + self.token_placeholder = self.container.empty() # for token streaming + self.text = ( + initial_text # The text content to display, starting with initial text + ) + + def on_llm_new_token(self, token: str, **kwargs) -> None: + """ + Callback method triggered when a new token is received (e.g., from a language model). + Args: + token (str): The new token received. + **kwargs: Additional keyword arguments. + """ + self.text += token # Append the new token to the existing text + self.token_placeholder.write(self.text) + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """ + Run when the tool starts running. + Args: + serialized (Dict[str, Any]): The serialized tool. + input_str (str): The input string. + kwargs (Any): Additional keyword arguments. + """ + with self.thoughts_placeholder: + status_placeholder = st.empty() # Placeholder to show the tool's status + with status_placeholder.status("Calling Tool...", expanded=True) as s: + st.write("Called tool: ", serialized["name"]) + st.write("tool description: ", serialized["description"]) + st.write("tool input: ") + st.code(input_str) # Display the input data sent to the tool + st.write("tool output: ") + # Placeholder for tool output that will be updated later below + self.tool_output_placeholder = st.empty() + s.update( + label="Completed Calling Tool!", expanded=False + ) # Update the status once done + + def on_tool_end(self, output: Any, **kwargs: Any) -> Any: + """ + Run when the tool ends. + Args: + output (Any): The output from the tool. + kwargs (Any): Additional keyword arguments. 
+ """ + # We assume that `on_tool_end` comes after `on_tool_start`, meaning output_placeholder exists + if self.tool_output_placeholder: + self.tool_output_placeholder.code( + output.content + ) # Display the tool's output + + # Define a type variable for generic type hinting in the decorator, to maintain + # input function and wrapped function return type + fn_return_type = TypeVar("fn_return_type") + + # Decorator function to add the Streamlit execution context to a function + def add_streamlit_context( + fn: Callable[..., fn_return_type] + ) -> Callable[..., fn_return_type]: + """ + Decorator to ensure that the decorated function runs within the Streamlit execution context. + Args: + fn (Callable[..., fn_return_type]): The function to be decorated. + Returns: + Callable[..., fn_return_type]: The decorated function that includes the Streamlit context setup. + """ + ctx = ( + get_script_run_ctx() + ) # Retrieve the current Streamlit script execution context + + def wrapper(*args, **kwargs) -> fn_return_type: + """ + Wrapper function that adds the Streamlit context and then calls the original function. + Args: + *args: Positional arguments to pass to the original function. + **kwargs: Keyword arguments to pass to the original function. + Returns: + fn_return_type: The result from the original function. 
+ """ + add_script_run_ctx( + ctx=ctx + ) # Add the Streamlit context to the current execution + return fn(*args, **kwargs) # Call the original function with its arguments + + return wrapper + + # Create an instance of the custom StreamHandler with the provided Streamlit container + st_cb = StreamHandler(parent_container) + + # Iterate over all methods of the StreamHandler instance + for method_name, method_func in inspect.getmembers( + st_cb, predicate=inspect.ismethod + ): + if method_name.startswith("on_"): # Identify callback methods + setattr( + st_cb, method_name, add_streamlit_context(method_func) + ) # Wrap and replace the method + + # Return the fully configured StreamHandler instance with the context-aware callback methods + return st_cb + + +def streamlit_invoke(graph, messages, callables): + if not isinstance(callables, list): + raise TypeError("callables must be a list") + return graph.invoke({"messages": messages}, config={"callbacks": callables}) diff --git a/src/rai/rai/tools/debugging_assistant.py b/src/rai/rai/tools/debugging_assistant.py new file mode 100644 index 000000000..2edb28ff8 --- /dev/null +++ b/src/rai/rai/tools/debugging_assistant.py @@ -0,0 +1,85 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import streamlit as st +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from rai.agents.conversational_agent import create_conversational_agent +from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke +from rai.tools.ros.debugging import ( + ros2_action, + ros2_interface, + ros2_node, + ros2_param, + ros2_service, + ros2_topic, +) +from rai.utils.model_initialization import get_llm_model + + +@st.cache_resource +def initialize_graph(): + llm = get_llm_model(model_type="complex_model", streaming=True) + agent = create_conversational_agent( + llm, + [ros2_topic, ros2_interface, ros2_node, ros2_service, ros2_action, ros2_param], + system_prompt="""You are a ROS 2 expert helping a user with their ROS 2 questions. You have access to various tools that allow you to query the ROS 2 system. + Be proactive and use the tools to answer questions. Retrieve as much information from the ROS 2 system as possible. + """, + ) + return agent + + +def main(): + st.set_page_config( + page_title="ROS 2 Debugging Assistant", + page_icon=":robot:", + ) + st.title("ROS 2 Debugging Assistant") + st.markdown("---") + + st.sidebar.header("Tool Calls History") + + if "graph" not in st.session_state: + graph = initialize_graph() + st.session_state["graph"] = graph + + if "messages" not in st.session_state: + st.session_state["messages"] = [ + AIMessage(content="Hi! I am a ROS 2 assistant. 
How can I help you?") + ] + + prompt = st.chat_input() + for msg in st.session_state.messages: + if isinstance(msg, AIMessage): + if msg.content: + st.chat_message("assistant").write(msg.content) + elif isinstance(msg, HumanMessage): + st.chat_message("user").write(msg.content) + elif isinstance(msg, ToolMessage): + with st.sidebar.expander(f"Tool: {msg.name}", expanded=False): + st.code(msg.content, language="json") + + if prompt: + st.session_state.messages.append(HumanMessage(content=prompt)) + st.chat_message("user").write(prompt) + with st.chat_message("assistant"): + st_callback = get_streamlit_cb(st.container()) + streamlit_invoke( + st.session_state["graph"], st.session_state.messages, [st_callback] + ) + + +if __name__ == "__main__": + main() diff --git a/src/rai/rai/tools/ros/debugging.py b/src/rai/rai/tools/ros/debugging.py new file mode 100644 index 000000000..3c2729caa --- /dev/null +++ b/src/rai/rai/tools/ros/debugging.py @@ -0,0 +1,171 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from subprocess import PIPE, Popen +from threading import Timer +from typing import List, Literal, Optional + +from langchain_core.tools import tool + +FORBIDDEN_CHARACTERS = ["&", ";", "|", "&&", "||", "(", ")", "<", ">", ">>", "<<"] + + +def run_with_timeout(cmd: List[str], timeout_sec: int): + proc = Popen(cmd, stdout=PIPE, stderr=PIPE) + timer = Timer(timeout_sec, proc.kill) + try: + timer.start() + stdout, stderr = proc.communicate() + return stdout, stderr + finally: + timer.cancel() + + +def run_command(cmd: List[str], timeout: int = 5): + # Validate command safety by checking for shell operators + # Block potentially dangerous characters + if any(char in " ".join(cmd) for char in FORBIDDEN_CHARACTERS): + raise ValueError( + "Command is not safe to run. The command contains forbidden characters." + ) + stdout, stderr = run_with_timeout(cmd, timeout) + output = {} + if stdout: + output["stdout"] = stdout.decode("utf-8") + else: + output["stdout"] = "Command returned no stdout output" + if stderr: + output["stderr"] = stderr.decode("utf-8") + else: + output["stderr"] = "Command returned no stderr output" + return str(output) + + +@tool +def ros2_action( + command: Literal["info", "list", "type", "send_goal"], + arguments: Optional[List[str]] = None, + timeout: int = 5, +): + """Run a ROS2 action command + Args: + command: The action command to run (info/list/type) + arguments: Additional arguments for the command as a list of strings + timeout: Command timeout in seconds + """ + cmd = ["ros2", "action", command] + if arguments: + cmd.extend(arguments) + return run_command(cmd, timeout) + + +@tool +def ros2_service( + command: Literal["call", "find", "info", "list", "type"], + arguments: Optional[List[str]] = None, + timeout: int = 5, +): + """Run a ROS2 service command + Args: + command: The service command to run + arguments: Additional arguments for the command as a list of strings + timeout: Command timeout in seconds + """ + cmd = ["ros2", "service", 
command] + if arguments: + cmd.extend(arguments) + return run_command(cmd, timeout) + + +@tool +def ros2_node( + command: Literal["info", "list"], + arguments: Optional[List[str]] = None, + timeout: int = 5, +): + """Run a ROS2 node command + Args: + command: The node command to run + arguments: Additional arguments for the command as a list of strings + timeout: Command timeout in seconds + """ + cmd = ["ros2", "node", command] + if arguments: + cmd.extend(arguments) + return run_command(cmd, timeout) + + +@tool +def ros2_param( + command: Literal["delete", "describe", "dump", "get", "list", "set"], + arguments: Optional[List[str]] = None, + timeout: int = 5, +): + """Run a ROS2 parameter command + Args: + command: The parameter command to run + arguments: Additional arguments for the command as a list of strings + timeout: Command timeout in seconds + """ + cmd = ["ros2", "param", command] + if arguments: + cmd.extend(arguments) + return run_command(cmd, timeout) + + +@tool +def ros2_interface( + command: Literal["list", "package", "packages", "proto", "show"], + arguments: Optional[List[str]] = None, + timeout: int = 5, +): + """Run a ROS2 interface command + Args: + command: The interface command to run + arguments: Additional arguments for the command as a list of strings + timeout: Command timeout in seconds + """ + cmd = ["ros2", "interface", command] + if arguments: + cmd.extend(arguments) + return run_command(cmd, timeout) + + +@tool +def ros2_topic( + command: Literal[ + "bw", "delay", "echo", "find", "hz", "info", "list", "pub", "type" + ], + arguments: Optional[List[str]] = None, + timeout: int = 5, +): + """Run a ROS2 topic command + Args: + command: The topic command to run: + - bw: Display bandwidth used by topic + - delay: Display delay of topic from timestamp in header + - echo: Output messages from a topic + - find: Output a list of available topics of a given type + - hz: Print the average publishing rate to screen + - info: Print information 
about a topic + - list: Output a list of available topics + - pub: Publish a message to a topic + - type: Print a topic's type + arguments: Additional arguments for the command as a list of strings + timeout: Command timeout in seconds + """ + cmd = ["ros2", "topic", command] + if arguments: + cmd.extend(arguments) + return run_command(cmd, timeout) diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai/rai/utils/model_initialization.py index 9e1fbd991..dd5af2b47 100644 --- a/src/rai/rai/utils/model_initialization.py +++ b/src/rai/rai/utils/model_initialization.py @@ -94,7 +94,7 @@ def load_config() -> RAIConfig: def get_llm_model( - model_type: Literal["simple_model", "complex_model"], vendor: str = None + model_type: Literal["simple_model", "complex_model"], vendor: str = None, **kwargs ): config = load_config() if vendor is None: @@ -110,18 +110,19 @@ def get_llm_model( if vendor == "openai": from langchain_openai import ChatOpenAI - return ChatOpenAI(model=model) + return ChatOpenAI(model=model, **kwargs) elif vendor == "aws": from langchain_aws import ChatBedrock return ChatBedrock( model_id=model, region_name=model_config.region_name, + **kwargs, ) elif vendor == "ollama": from langchain_ollama import ChatOllama - return ChatOllama(model=model, base_url=model_config.base_url) + return ChatOllama(model=model, base_url=model_config.base_url, **kwargs) else: raise ValueError(f"Unknown LLM vendor: {vendor}")