From 3eb90b52c49dbfb6b80b10a0d05d5ea021c5fb66 Mon Sep 17 00:00:00 2001 From: Ethan Zhang Date: Mon, 9 Dec 2024 11:39:51 +0800 Subject: [PATCH] feat: better module registry (#460) * feat: better module registry * feat: add placeholder when no compatible tools * Update playground/src/components/Chat/ChatCfgModuleSelect.tsx Co-authored-by: czhen <56986964+shczhen@users.noreply.github.com> * feat: change module type to enum * feat: update playground image --------- Co-authored-by: czhen <56986964+shczhen@users.noreply.github.com> --- docker-compose.yml | 2 +- playground/src/common/hooks.ts | 58 +++++- playground/src/common/moduleConfig.ts | 140 +++++++++++++ .../components/Chat/ChatCfgModuleSelect.tsx | 194 +++++++++--------- 4 files changed, 288 insertions(+), 106 deletions(-) create mode 100644 playground/src/common/moduleConfig.ts diff --git a/docker-compose.yml b/docker-compose.yml index 4d24650c..28666d2b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: networks: - ten_agent_network ten_agent_playground: - image: ghcr.io/ten-framework/ten_agent_playground:0.6.1-41-g7292256 + image: ghcr.io/ten-framework/ten_agent_playground:0.6.1-39-gcda3b08 container_name: ten_agent_playground restart: always ports: diff --git a/playground/src/common/hooks.ts b/playground/src/common/hooks.ts index 3b9d8d40..405eec32 100644 --- a/playground/src/common/hooks.ts +++ b/playground/src/common/hooks.ts @@ -7,6 +7,7 @@ import type { AppDispatch, AppStore, RootState } from "../store"; import { useDispatch, useSelector, useStore } from "react-redux"; import { Node, AddonDef, Graph } from "@/common/graph"; import { fetchGraphDetails, initializeGraphData, updateGraph } from "@/store/reducers/global"; +import { moduleRegistry, ModuleRegistry, toolModuleRegistry } from "@/common/moduleConfig"; // import { Grid } from "antd" // const { useBreakpoint } = Grid; @@ -139,15 +140,7 @@ const useGraphs = () => { ) const graphMap = useAppSelector((state) => state.global.graphMap) const selectedGraph = graphMap[selectedGraphId] - const addonModules = useAppSelector((state) => state.global.addonModules) - - // Extract tool modules from addonModules - const toolModules = useMemo( - () => addonModules.filter((module) => module.name.includes("tool") - && module.name !== "vision_analyze_tool_python" - ), - [addonModules], - ) + const addonModules: AddonDef.Module[] = useAppSelector((state) => state.global.addonModules); const initialize = async () => { await dispatch(initializeGraphData()) @@ -168,7 +161,7 @@ const useGraphs = () => { if (!selectedGraph) { return null } - const node = selectedGraph.nodes.find((node) => node.name === nodeName) + const node = selectedGraph.nodes.find((node: Node) => node.name === nodeName) if (!node) { return null } @@ -177,13 +170,56 @@ const useGraphs = () => { [selectedGraph], ) + + const getInstalledAndRegisteredModulesMap = useCallback(() => { + const groupedModules: Record = { + stt: [], + tts: [], + llm: [], + v2v: [] + } + + addonModules.forEach((addonModule) => { + const registeredModule = moduleRegistry[addonModule.name]; + if (registeredModule && registeredModule.type !== "tool") { + groupedModules[registeredModule.type].push(registeredModule); + } + }); + + return groupedModules; + }, [addonModules]); + + const getInstalledAndRegisteredToolModules = useCallback(() => { + const toolModules: ModuleRegistry.ToolModule[] = []; + + addonModules.forEach((addonModule) => { + const registeredModule = toolModuleRegistry[addonModule.name]; + if (registeredModule && registeredModule.type === "tool") { + toolModules.push(registeredModule); + } + }); + + return toolModules; + }, [addonModules]) + + const installedAndRegisteredModulesMap = useMemo( + () => getInstalledAndRegisteredModulesMap(), + [getInstalledAndRegisteredModulesMap], + ); + + const installedAndRegisteredToolModules = useMemo( + () => getInstalledAndRegisteredToolModules(), + [getInstalledAndRegisteredToolModules], + ); + return { initialize, fetchDetails, update, getGraphNodeAddonByName, selectedGraph, - toolModules, + installedAndRegisteredModulesMap, + installedAndRegisteredToolModules, } } diff --git a/playground/src/common/moduleConfig.ts b/playground/src/common/moduleConfig.ts new file mode 100644 index 00000000..0b21c173 --- /dev/null +++ b/playground/src/common/moduleConfig.ts @@ -0,0 +1,140 @@ +export namespace ModuleRegistry { + export enum ModuleType { + STT = "stt", + LLM = "llm", + V2V = "v2v", + TTS = "tts", + TOOL = "tool", + } + + export interface Module { + name: string; + type: ModuleType; + label: string; + } + + export type NonToolModuleType = Exclude + export type NonToolModule = Module & { type: NonToolModuleType }; + export type ToolModule = Module & { type: ModuleType.TOOL }; +} + + +// Custom labels for specific keys +export const ModuleTypeLabels: Record = { + [ModuleRegistry.ModuleType.STT]: "STT (Speech to Text)", + [ModuleRegistry.ModuleType.LLM]: "LLM (Large Language Model)", + [ModuleRegistry.ModuleType.TTS]: "TTS (Text to Speech)", + [ModuleRegistry.ModuleType.V2V]: "LLM v2v (V2V Large Language Model)", +}; + +export const sttModuleRegistry: Record = { + deepgram_asr_python: { + name: "deepgram_asr_python", + type: ModuleRegistry.ModuleType.STT, + label: "Deepgram STT", + }, + transcribe_asr_python: { + name: "transcribe_asr_python", + type: ModuleRegistry.ModuleType.STT, + label: "Transcribe STT", + } +} + +export const llmModuleRegistry: Record = { + openai_chatgpt_python: { + name: "openai_chatgpt_python", + type: ModuleRegistry.ModuleType.LLM, + label: "OpenAI ChatGPT", + }, + gemini_llm_python: { + name: "gemini_llm_python", + type: ModuleRegistry.ModuleType.LLM, + label: "Gemini LLM", + }, + bedrock_llm_python: { + name: "bedrock_llm_python", + type: ModuleRegistry.ModuleType.LLM, + label: "Bedrock LLM", + }, +} + +export const ttsModuleRegistry: Record = { + azure_tts: { + name: "azure_tts", + type: ModuleRegistry.ModuleType.TTS, + label: "Azure TTS", + }, + cartesia_tts: { + name: "cartesia_tts", + type: ModuleRegistry.ModuleType.TTS, + label: "Cartesia TTS", + }, + cosy_tts_python: { + name: "cosy_tts_python", + type: ModuleRegistry.ModuleType.TTS, + label: "Cosy TTS", + }, + elevenlabs_tts_python: { + name: "elevenlabs_tts_python", + type: ModuleRegistry.ModuleType.TTS, + label: "Elevenlabs TTS", + }, + fish_audio_tts: { + name: "fish_audio_tts", + type: ModuleRegistry.ModuleType.TTS, + label: "Fish Audio TTS", + }, + minimax_tts_python: { + name: "minimax_tts_python", + type: ModuleRegistry.ModuleType.TTS, + label: "Minimax TTS", + }, + polly_tts: { + name: "polly_tts", + type: ModuleRegistry.ModuleType.TTS, + label: "Polly TTS", + } +} + +export const v2vModuleRegistry: Record = { + openai_v2v_python: { + name: "openai_v2v_python", + type: ModuleRegistry.ModuleType.V2V, + label: "OpenAI Realtime", + } +} + +export const toolModuleRegistry: Record = { + vision_analyze_tool_python: { + name: "vision_analyze_tool_python", + type: ModuleRegistry.ModuleType.TOOL, + label: "Vision Analyze Tool", + }, + weatherapi_tool_python: { + name: "weatherapi_tool_python", + type: ModuleRegistry.ModuleType.TOOL, + label: "WeatherAPI Tool", + }, + bingsearch_tool_python: { + name: "bingsearch_tool_python", + type: ModuleRegistry.ModuleType.TOOL, + label: "BingSearch Tool", + }, + vision_tool_python: { + name: "vision_tool_python", + type: ModuleRegistry.ModuleType.TOOL, + label: "Vision Tool", + }, +} + +export const moduleRegistry: Record = { + ...sttModuleRegistry, + ...llmModuleRegistry, + ...ttsModuleRegistry, + ...v2vModuleRegistry +} + +export const compatibleTools: Record = { + openai_chatgpt_python: ["vision_tool_python", "weatherapi_tool_python", "bingsearch_tool_python"], + openai_v2v_python: ["weatherapi_tool_python", "bingsearch_tool_python"], +} \ No newline at end of file diff --git a/playground/src/components/Chat/ChatCfgModuleSelect.tsx b/playground/src/components/Chat/ChatCfgModuleSelect.tsx index 4659ab84..dcae871c 100644 --- a/playground/src/components/Chat/ChatCfgModuleSelect.tsx +++ b/playground/src/components/Chat/ChatCfgModuleSelect.tsx @@ -30,7 +30,7 @@ import { } from "@/components/ui/form" import { Button } from "@/components/ui/button" import { cn } from "@/lib/utils" -import { useAppSelector, useGraphs } from "@/common/hooks" +import { useAppSelector, useGraphs, } from "@/common/hooks" import { AddonDef, Graph, Destination, GraphEditor, ProtocolLabel as GraphConnProtocol } from "@/common/graph" import { toast } from "sonner" import { BoxesIcon, ChevronRightIcon, LoaderCircleIcon, SettingsIcon, Trash2Icon, WrenchIcon } from "lucide-react" @@ -39,74 +39,60 @@ import { zodResolver } from "@hookform/resolvers/zod" import { z } from "zod" import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuPortal, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger, DropdownMenuTrigger } from "../ui/dropdown" import { isLLM } from "@/common" +import { compatibleTools, ModuleRegistry, ModuleTypeLabels } from "@/common/moduleConfig" export function RemoteModuleCfgSheet() { const addonModules = useAppSelector((state) => state.global.addonModules); - const { getGraphNodeAddonByName, selectedGraph, update: updateGraph } = useGraphs(); - - const moduleMapping: Record = { - stt: [], - llm: ["openai_chatgpt_python"], - v2v: [], - tts: [], - }; - - // Define the exclusion map for modules - const exclusionMapping: Record = { - stt: [], - llm: ["qwen_llm_python"], - v2v: ["minimax_v2v_python"], - tts: [], - }; - - const modules = React.useMemo(() => { - const result: Record = {}; - - addonModules.forEach((module) => { - const matchingNode = selectedGraph?.nodes.find((node) => - ["stt", "tts", "llm", "v2v"].some((type) => - node.name === type && - (module.name.includes(type) || - (type === "stt" && module.name.includes("asr")) || - (moduleMapping[type]?.includes(module.name))) - ) - ); - - if ( - matchingNode && - !exclusionMapping[matchingNode.name]?.includes(module.name) - ) { - if (!result[matchingNode.name]) { - result[matchingNode.name] = []; - } - result[matchingNode.name].push(module.name); - } - }); - - return result; - }, [addonModules, selectedGraph]); - - const { toolModules } = useGraphs(); + const { getGraphNodeAddonByName, selectedGraph, update: updateGraph, installedAndRegisteredModulesMap, installedAndRegisteredToolModules } = useGraphs(); const metadata = React.useMemo(() => { - const dynamicMetadata: Record = {}; - - Object.keys(modules).forEach((key) => { - dynamicMetadata[key] = { type: "string", options: modules[key] }; - }); + const dynamicMetadata: Record = {}; + + if (selectedGraph) { + Object.keys(installedAndRegisteredModulesMap).forEach((key) => { + const moduleTypeKey = key as ModuleRegistry.NonToolModuleType; + + // Check if the current graph has a node whose name contains the ModuleType + const hasMatchingNode = selectedGraph.nodes.some((node) => + node.name.includes(moduleTypeKey) + ); + + if (hasMatchingNode) { + dynamicMetadata[moduleTypeKey] = { + type: "string", + options: installedAndRegisteredModulesMap[moduleTypeKey].map((module) => ({ + value: module.name, + label: module.label, + })), + }; + } + }); + } return dynamicMetadata; - }, [modules]); + }, [installedAndRegisteredModulesMap, selectedGraph]); const initialData = React.useMemo(() => { const dynamicInitialData: Record = {}; - Object.keys(modules).forEach((key) => { - dynamicInitialData[key] = getGraphNodeAddonByName(key)?.addon; - }); + if (selectedGraph) { + Object.keys(installedAndRegisteredModulesMap).forEach((key) => { + const moduleTypeKey = key as ModuleRegistry.ModuleType; + + // Check if the current graph has a node whose name contains the ModuleType + const hasMatchingNode = selectedGraph.nodes.some((node) => + node.name.includes(moduleTypeKey) + ); + + if (hasMatchingNode) { + dynamicInitialData[moduleTypeKey] = getGraphNodeAddonByName(moduleTypeKey)?.addon; + } + }); + } return dynamicInitialData; - }, [modules, getGraphNodeAddonByName]); + }, [installedAndRegisteredModulesMap, selectedGraph, getGraphNodeAddonByName]); + return ( @@ -137,6 +123,17 @@ export function RemoteModuleCfgSheet() { const nodes = selectedGraphCopy.nodes; let needUpdate = false; + + // Update graph nodes with selected modules + Object.entries(data).forEach(([key, value]) => { + const node = nodes.find((n) => n.name === key); + if (node && value && node.addon !== value) { + node.addon = value; + node.property = addonModules.find((module) => module.name === value)?.defaultProperty; + needUpdate = true; + } + }); + // Retrieve the agora_rtc node const agoraRtcNode = GraphEditor.findNode(selectedGraphCopy, "agora_rtc"); if (!agoraRtcNode) { @@ -146,7 +143,7 @@ export function RemoteModuleCfgSheet() { // Identify removed tools and process them const currentToolsInGraph = nodes - .filter((node) => toolModules.map((module) => module.name).includes(node.addon)) + .filter((node) => installedAndRegisteredToolModules.map((module) => module.name).includes(node.addon)) .map((node) => node.addon); const removedTools = currentToolsInGraph.filter((tool) => !tools.includes(tool)); @@ -180,17 +177,6 @@ export function RemoteModuleCfgSheet() { needUpdate = true; } - - // Update graph nodes with selected modules - Object.entries(data).forEach(([key, value]) => { - const node = nodes.find((n) => n.name === key); - if (node && value && node.addon !== value) { - node.addon = value; - node.property = addonModules.find((module) => module.name === value)?.defaultProperty; - needUpdate = true; - } - }); - // Perform the update if changes are detected if (needUpdate) { try { @@ -217,43 +203,61 @@ const GraphModuleCfgForm = ({ onUpdate, }: { initialData: Record; - metadata: Record; + metadata: Record; onUpdate: (data: Record, tools: string[]) => void; }) => { const formSchema = z.record(z.string(), z.string().nullable()); - const { selectedGraph, toolModules } = useGraphs(); - const form = useForm>({ resolver: zodResolver(formSchema), defaultValues: initialData, }); + const { selectedGraph, installedAndRegisteredToolModules } = useGraphs(); + const { watch } = form; - const onSubmit = (data: z.infer) => { - onUpdate(data, selectedTools); - }; + // Watch for changes in "llm" and "v2v" fields + const llmValue = watch("llm"); + const v2vValue = watch("v2v"); + const toolModules = React.useMemo(() => { + // Step 1: Get installed and registered tool modules + const allToolModules = installedAndRegisteredToolModules || []; + + // Step 2: Determine the active module based on form values + const activeModule = llmValue || v2vValue; + // Step 3: Get compatible tools for the active module + if (activeModule) { + const compatibleToolNames = compatibleTools[activeModule] || []; + return allToolModules.filter((module) => compatibleToolNames.includes(module.name)); + } - // Custom labels for specific keys - const fieldLabels: Record = { - stt: "STT (Speech to Text)", - llm: "LLM (Large Language Model)", - tts: "TTS (Text to Speech)", - v2v: "LLM v2v (V2V Large Language Model)", + // If no LLM or V2V module is selected, return all tool modules + return []; + }, [installedAndRegisteredToolModules, selectedGraph, llmValue, v2vValue]); + + + const onSubmit = (data: z.infer) => { + onUpdate(data, selectedTools); }; + const [selectedTools, setSelectedTools] = React.useState([]); - // Initialize selectedTools by extracting tool addons used in graph nodes - const initialSelectedTools = React.useMemo(() => { + // Synchronize selectedTools with selectedGraph and toolModules + React.useEffect(() => { const toolNames = toolModules.map((module) => module.name); - return selectedGraph?.nodes - .filter((node) => toolNames.includes(node.addon)) - .map((node) => node.addon) || []; + const graphToolAddons = + selectedGraph?.nodes + .filter((node) => toolNames.includes(node.addon)) + .map((node) => node.addon) || []; + setSelectedTools(graphToolAddons); }, [toolModules, selectedGraph]); - const [selectedTools, setSelectedTools] = React.useState(initialSelectedTools); - // Desired field order - const fieldOrder = ["stt", "llm", "v2v", "tts"]; + const fieldOrder: ModuleRegistry.NonToolModuleType[] = [ + ModuleRegistry.ModuleType.STT, + ModuleRegistry.ModuleType.LLM, + ModuleRegistry.ModuleType.V2V, + ModuleRegistry.ModuleType.TTS, + ]; return (
@@ -267,8 +271,8 @@ const GraphModuleCfgForm = ({ render={({ field }) => ( -
-
{fieldLabels[key]}
+
+
{ModuleTypeLabels[key]}
{isLLM(key) && ( - {toolModules.map((module) => ( + {toolModules.length > 0 ? toolModules.map((module) => ( {module.name} - ))} + )) : ( + No compatible tools + )} @@ -320,8 +326,8 @@ const GraphModuleCfgForm = ({ {metadata[key].options.map((option) => ( - - {option} + + {option.label} ))}