Skip to content

Commit

Permalink
feat: support link tool
Browse files Browse the repository at this point in the history
  • Loading branch information
plutoless committed Dec 3, 2024
1 parent d35ae84 commit ee62dce
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 66 deletions.
51 changes: 45 additions & 6 deletions agents/examples/default/property.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
"agora_asr_language": "en-US",
"agora_asr_vendor_key": "${env:AZURE_STT_KEY|}",
"agora_asr_vendor_region": "${env:AZURE_STT_REGION|}",
"agora_asr_session_control_file_path": "session_control.conf"
"agora_asr_session_control_file_path": "session_control.conf",
"subscribe_video_pix_fmt": 4,
"subscribe_video": true
}
},
{
Expand Down Expand Up @@ -83,6 +85,13 @@
"extension_group": "transcriber",
"property": {}
},
{
"type": "extension",
"name": "vision_tool_python",
"addon": "vision_tool_python",
"extension_group": "default",
"property": {}
},
{
"type": "extension",
"name": "weatherapi_tool_python",
Expand Down Expand Up @@ -136,6 +145,17 @@
}
]
}
],
"video_frame": [
{
"name": "video_frame",
"dest": [
{
"extension_group": "default",
"extension": "vision_tool_python"
}
]
}
]
},
{
Expand Down Expand Up @@ -173,6 +193,10 @@
{
"name": "tool_call",
"dest": [
{
"extension_group": "default",
"extension": "vision_tool_python"
},
{
"extension_group": "default",
"extension": "weatherapi_tool_python"
Expand Down Expand Up @@ -263,6 +287,21 @@
}
]
},
{
"extension_group": "default",
"extension": "vision_tool_python",
"cmd": [
{
"name": "tool_register",
"dest": [
{
"extension_group": "chatgpt",
"extension": "llm"
}
]
}
]
},
{
"extension_group": "default",
"extension": "weatherapi_tool_python",
Expand Down Expand Up @@ -327,11 +366,11 @@
},
{
"type": "extension",
"name": "weatherapi_tool_python",
"addon": "weatherapi_tool_python",
"name": "bingsearch_tool_python",
"addon": "bingsearch_tool_python",
"extension_group": "default",
"property": {
"api_key": "${env:WEATHERAPI_API_KEY|}"
"api_key": "${env:BING_API_KEY|}"
}
}
],
Expand Down Expand Up @@ -369,7 +408,7 @@
"dest": [
{
"extension_group": "default",
"extension": "weatherapi_tool_python"
"extension": "bingsearch_tool_python"
}
]
}
Expand Down Expand Up @@ -414,7 +453,7 @@
},
{
"extension_group": "default",
"extension": "weatherapi_tool_python",
"extension": "bingsearch_tool_python",
"cmd": [
{
"name": "tool_register",
Expand Down
75 changes: 55 additions & 20 deletions playground/src/common/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -451,22 +451,22 @@ class GraphEditor {
? connection.cmd.filter((cmd) => cmd.dest?.length > 0)
: undefined;
if (!connection.cmd?.length) delete connection.cmd;

connection.data = Array.isArray(connection.data)
? connection.data.filter((data) => data.dest?.length > 0)
: undefined;
if (!connection.data?.length) delete connection.data;

connection.audio_frame = Array.isArray(connection.audio_frame)
? connection.audio_frame.filter((audio) => audio.dest?.length > 0)
: undefined;
if (!connection.audio_frame?.length) delete connection.audio_frame;

connection.video_frame = Array.isArray(connection.video_frame)
? connection.video_frame.filter((video) => video.dest?.length > 0)
: undefined;
if (!connection.video_frame?.length) delete connection.video_frame;

// Check if at least one protocol remains
return (
connection.cmd?.length ||
Expand All @@ -476,8 +476,8 @@ class GraphEditor {
);
});
}



static removeNodeAndConnections(graph: Graph, addon: string): void {
// Remove the node
Expand Down Expand Up @@ -506,26 +506,61 @@ class GraphEditor {
connection.video_frame?.length)
)
})
// Clean up empty connections
GraphEditor.removeEmptyConnections(graph);
}

/**
* Link a tool to an LLM in the graph
*/
static linkLLMTool(
graph: Graph,
llmExtension: string,
toolExtension: string,
): void {
const llmNode = graph.nodes.find((node) => node.name === llmExtension)
const toolNode = graph.nodes.find((node) => node.name === toolExtension)
* Link a tool to an LLM node by creating the appropriate connections.
*/
static linkTool(graph: Graph, llmNode: Node, toolNode: Node): void {
const llmExtensionGroup = llmNode.extensionGroup;

// Create the connection from the LLM node to the tool node
GraphEditor.addOrUpdateConnection(
graph,
`${llmExtensionGroup}.${llmNode.name}`,
`${toolNode.extensionGroup}.${toolNode.name}`,
GraphConnProtocol.CMD,
"tool_call"
);

// Create the connection from the tool node back to the LLM node
GraphEditor.addOrUpdateConnection(
graph,
`${toolNode.extensionGroup}.${toolNode.name}`,
`${llmExtensionGroup}.${llmNode.name}`,
GraphConnProtocol.CMD,
"tool_register"
);

const rtcModule = GraphEditor.findNodeByPredicate(graph, (node) => node.addon.includes("rtc"));
if (toolNode.addon.includes("vision") && rtcModule) {
// Create the connection from the RTC node to the tool node to deliver video frame
GraphEditor.addOrUpdateConnection(
graph,
`${rtcModule.extensionGroup}.${rtcModule.name}`,
`${toolNode.extensionGroup}.${toolNode.name}`,
GraphConnProtocol.VIDEO_FRAME,
"video_frame"
);
}
}

if (!llmNode || !toolNode) {
throw new Error(
`Either LLM "${llmExtension}" or Tool "${toolExtension}" does not exist in graph "${graph.id}".`,
)
static enableRTCVideoSubscribe(graph: Graph, enabled: Boolean): void {
const rtcNode = GraphEditor.findNodeByPredicate(graph, (node) => node.addon.includes("rtc"));
if (!rtcNode) {
throw new Error("RTC node not found in the graph.");
}

// this.addConnection(graph, llmExtension, toolExtension, "llm_tool_link")
if (enabled) {
GraphEditor.updateNodeProperty(graph, rtcNode.name, {
subscribe_video_pix_fmt: 4,
subscribe_video: true,
});
} else {
GraphEditor.removeNodeProperties(graph, rtcNode.name, ["subscribe_video_pix_fmt", "subscribe_video"]);
}
}
}

Expand Down
43 changes: 3 additions & 40 deletions playground/src/components/Chat/ChatCfgModuleSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import {
} from "@/components/ui/form"
import { Button } from "@/components/ui/button"
import { cn } from "@/lib/utils"
import { useAppDispatch, useAppSelector, useGraphs } from "@/common/hooks"
import { useAppSelector, useGraphs } from "@/common/hooks"
import { AddonDef, Graph, Destination, GraphEditor, ProtocolLabel as GraphConnProtocol } from "@/common/graph"
import { toast } from "sonner"
import { BoxesIcon, ChevronRightIcon, LoaderCircleIcon, SettingsIcon, Trash2Icon, WrenchIcon } from "lucide-react"
Expand Down Expand Up @@ -157,17 +157,7 @@ export function RemoteModuleCfgSheet() {

// Process tool modules
if (tools.length > 0) {
if (tools.some((tool) => tool.includes("vision"))) {
GraphEditor.updateNodeProperty(selectedGraphCopy, "agora_rtc", {
subscribe_video_pix_fmt: 4,
subscribe_video: true,
});
} else {
GraphEditor.removeNodeProperties(selectedGraphCopy, "agora_rtc", [
"subscribe_video_pix_fmt",
"subscribe_video",
]);
}
GraphEditor.enableRTCVideoSubscribe(selectedGraphCopy, tools.some((tool) => tool.includes("vision")));

tools.forEach((tool) => {
if (!currentToolsInGraph.includes(tool)) {
Expand All @@ -183,40 +173,13 @@ export function RemoteModuleCfgSheet() {
// Create or update connections
const llmNode = GraphEditor.findNodeByPredicate(selectedGraphCopy, (node) => isLLM(node.name));
if (llmNode) {
const llmExtensionGroup = llmNode.extensionGroup;
GraphEditor.addOrUpdateConnection(
selectedGraphCopy,
`${llmExtensionGroup}.${llmNode.name}`,
`${toolNode.extensionGroup}.${toolNode.name}`,
GraphConnProtocol.CMD,
"tool_call"
);
GraphEditor.addOrUpdateConnection(
selectedGraphCopy,
`${toolNode.extensionGroup}.${toolNode.name}`,
`${llmExtensionGroup}.${llmNode.name}`,
GraphConnProtocol.CMD,
"tool_register"
);
}

if (tool.includes("vision")) {
GraphEditor.addOrUpdateConnection(
selectedGraphCopy,
`${agoraRtcNode.extensionGroup}.${agoraRtcNode.name}`,
`${toolNode.extensionGroup}.${toolNode.name}`,
GraphConnProtocol.VIDEO_FRAME,
"video_frame"
);
GraphEditor.linkTool(selectedGraphCopy, llmNode, toolNode);
}
}
});
needUpdate = true;
}

// Remove empty connections
GraphEditor.removeEmptyConnections(selectedGraphCopy);


// Update graph nodes with selected modules
Object.entries(data).forEach(([key, value]) => {
Expand Down

0 comments on commit ee62dce

Please sign in to comment.