Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ packaging==25.0
pandas==2.2.3
pillow==11.2.1
psutil==7.0.0
sentence-transformers==2.5.1
scikit-learn==1.5.2
pycparser==2.22
pydantic==2.11.3
pydantic-settings==2.10.1
Expand Down
6 changes: 6 additions & 0 deletions slaver/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
tool:
# Has the model undergone targeted training on tool_calls
support_tool_calls: false
# Tool matching configuration
matching:
# Maximum number of tools to match for each task
max_tools: 3
# Minimum similarity score threshold (0.0 to 1.0)
min_similarity: 0.1

# Cloud Server Infos
model:
Expand Down
27 changes: 25 additions & 2 deletions slaver/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from tools.utils import Config
from tools.tool_matcher import ToolMatcher

config = Config.load_config()
collaborator = Collaborator.from_config(config=config["collaborator"])
Expand All @@ -39,6 +40,12 @@ def __init__(self):
self.threads = []
self.loop = asyncio.get_event_loop()
self.robot_name = None

# Initialize tool matcher with configuration
self.tool_matcher = ToolMatcher(
max_tools=config["tool"]["matching"]["max_tools"],
min_similarity=config["tool"]["matching"]["min_similarity"]
)

signal.signal(signal.SIGINT, self._handle_signal)
signal.signal(signal.SIGTERM, self._handle_signal)
Expand Down Expand Up @@ -108,8 +115,21 @@ async def _execute_task(self, task_data: Dict) -> None:
return

os.makedirs("./.log", exist_ok=True)

# Use tool matcher to find relevant tools for the task
task = task_data["task"]
matched_tools = self.tool_matcher.match_tools(task)

# Filter tools based on matching results
if matched_tools:
matched_tool_names = [tool_name for tool_name, _ in matched_tools]
filtered_tools = [tool for tool in self.tools
if tool.get("function", {}).get("name") in matched_tool_names]
else:
filtered_tools = self.tools

agent = ToolCallingAgent(
tools=self.tools,
tools=filtered_tools,
verbosity_level=2,
model=self.model,
model_path=self.model_path,
Expand All @@ -118,7 +138,7 @@ async def _execute_task(self, task_data: Dict) -> None:
collaborator=self.collaborator,
tool_executor=self.session.call_tool,
)
task = task_data["task"]

result = await agent.run(task)
self._send_result(
robot_name=self.robot_name,
Expand Down Expand Up @@ -200,6 +220,9 @@ async def connect_to_robot(self):
for tool in response.tools
]
print("Connected to robot with tools:", str(self.tools))

# Train the tool matcher with the available tools
self.tool_matcher.fit(self.tools)

"""Complete robot registration with thread management"""
robot_name = config["robot"]["name"]
Expand Down
Loading