Skip to content

Commit

Permalink
feat: better module registry (#460)
Browse files Browse the repository at this point in the history
* feat: better module registry

* feat: add placeholder when no compatible tools

* Update playground/src/components/Chat/ChatCfgModuleSelect.tsx

Co-authored-by: czhen <[email protected]>

* feat: change module type to enum

* feat: update playground image

---------

Co-authored-by: czhen <[email protected]>
  • Loading branch information
plutoless and shczhen authored Dec 9, 2024
1 parent d913c2d commit 3eb90b5
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 106 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ services:
networks:
- ten_agent_network
ten_agent_playground:
image: ghcr.io/ten-framework/ten_agent_playground:0.6.1-41-g7292256
image: ghcr.io/ten-framework/ten_agent_playground:0.6.1-39-gcda3b08
container_name: ten_agent_playground
restart: always
ports:
Expand Down
58 changes: 47 additions & 11 deletions playground/src/common/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { AppDispatch, AppStore, RootState } from "../store";
import { useDispatch, useSelector, useStore } from "react-redux";
import { Node, AddonDef, Graph } from "@/common/graph";
import { fetchGraphDetails, initializeGraphData, updateGraph } from "@/store/reducers/global";
import { moduleRegistry, ModuleRegistry, toolModuleRegistry } from "@/common/moduleConfig";
// import { Grid } from "antd"

// const { useBreakpoint } = Grid;
Expand Down Expand Up @@ -139,15 +140,7 @@ const useGraphs = () => {
)
const graphMap = useAppSelector((state) => state.global.graphMap)
const selectedGraph = graphMap[selectedGraphId]
const addonModules = useAppSelector((state) => state.global.addonModules)

// Extract tool modules from addonModules
const toolModules = useMemo(
() => addonModules.filter((module) => module.name.includes("tool")
&& module.name !== "vision_analyze_tool_python"
),
[addonModules],
)
const addonModules: AddonDef.Module[] = useAppSelector((state) => state.global.addonModules);

const initialize = async () => {
await dispatch(initializeGraphData())
Expand All @@ -168,7 +161,7 @@ const useGraphs = () => {
if (!selectedGraph) {
return null
}
const node = selectedGraph.nodes.find((node) => node.name === nodeName)
const node = selectedGraph.nodes.find((node: Node) => node.name === nodeName)
if (!node) {
return null
}
Expand All @@ -177,13 +170,56 @@ const useGraphs = () => {
[selectedGraph],
)


const getInstalledAndRegisteredModulesMap = useCallback(() => {
const groupedModules: Record<ModuleRegistry.NonToolModuleType, ModuleRegistry.Module[]> = {
stt: [],
tts: [],
llm: [],
v2v: []
}

addonModules.forEach((addonModule) => {
const registeredModule = moduleRegistry[addonModule.name];
if (registeredModule && registeredModule.type !== "tool") {
groupedModules[registeredModule.type].push(registeredModule);
}
});

return groupedModules;
}, [addonModules]);

const getInstalledAndRegisteredToolModules = useCallback(() => {
const toolModules: ModuleRegistry.ToolModule[] = [];

addonModules.forEach((addonModule) => {
const registeredModule = toolModuleRegistry[addonModule.name];
if (registeredModule && registeredModule.type === "tool") {
toolModules.push(registeredModule);
}
});

return toolModules;
}, [addonModules])

const installedAndRegisteredModulesMap = useMemo(
() => getInstalledAndRegisteredModulesMap(),
[getInstalledAndRegisteredModulesMap],
);

const installedAndRegisteredToolModules = useMemo(
() => getInstalledAndRegisteredToolModules(),
[getInstalledAndRegisteredToolModules],
);

return {
initialize,
fetchDetails,
update,
getGraphNodeAddonByName,
selectedGraph,
toolModules,
installedAndRegisteredModulesMap,
installedAndRegisteredToolModules,
}
}

Expand Down
140 changes: 140 additions & 0 deletions playground/src/common/moduleConfig.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
export namespace ModuleRegistry {
export enum ModuleType {
STT = "stt",
LLM = "llm",
V2V = "v2v",
TTS = "tts",
TOOL = "tool",
}

export interface Module {
name: string;
type: ModuleType;
label: string;
}

export type NonToolModuleType = Exclude<ModuleType, ModuleType.TOOL>
export type NonToolModule = Module & { type: NonToolModuleType };
export type ToolModule = Module & { type: ModuleType.TOOL };
}


// Custom labels for specific keys
export const ModuleTypeLabels: Record<ModuleRegistry.NonToolModuleType, string> = {
[ModuleRegistry.ModuleType.STT]: "STT (Speech to Text)",
[ModuleRegistry.ModuleType.LLM]: "LLM (Large Language Model)",
[ModuleRegistry.ModuleType.TTS]: "TTS (Text to Speech)",
[ModuleRegistry.ModuleType.V2V]: "LLM v2v (V2V Large Language Model)",
};

export const sttModuleRegistry: Record<string, ModuleRegistry.Module> = {
deepgram_asr_python: {
name: "deepgram_asr_python",
type: ModuleRegistry.ModuleType.STT,
label: "Deepgram STT",
},
transcribe_asr_python: {
name: "transcribe_asr_python",
type: ModuleRegistry.ModuleType.STT,
label: "Transcribe STT",
}
}

export const llmModuleRegistry: Record<string, ModuleRegistry.Module> = {
openai_chatgpt_python: {
name: "openai_chatgpt_python",
type: ModuleRegistry.ModuleType.LLM,
label: "OpenAI ChatGPT",
},
gemini_llm_python: {
name: "gemini_llm_python",
type: ModuleRegistry.ModuleType.LLM,
label: "Gemini LLM",
},
bedrock_llm_python: {
name: "bedrock_llm_python",
type: ModuleRegistry.ModuleType.LLM,
label: "Bedrock LLM",
},
}

export const ttsModuleRegistry: Record<string, ModuleRegistry.Module> = {
azure_tts: {
name: "azure_tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Azure TTS",
},
cartesia_tts: {
name: "cartesia_tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Cartesia TTS",
},
cosy_tts_python: {
name: "cosy_tts_python",
type: ModuleRegistry.ModuleType.TTS,
label: "Cosy TTS",
},
elevenlabs_tts_python: {
name: "elevenlabs_tts_python",
type: ModuleRegistry.ModuleType.TTS,
label: "Elevenlabs TTS",
},
fish_audio_tts: {
name: "fish_audio_tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Fish Audio TTS",
},
minimax_tts_python: {
name: "minimax_tts_python",
type: ModuleRegistry.ModuleType.TTS,
label: "Minimax TTS",
},
polly_tts: {
name: "polly_tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Polly TTS",
}
}

export const v2vModuleRegistry: Record<string, ModuleRegistry.Module> = {
openai_v2v_python: {
name: "openai_v2v_python",
type: ModuleRegistry.ModuleType.V2V,
label: "OpenAI Realtime",
}
}

export const toolModuleRegistry: Record<string, ModuleRegistry.ToolModule> = {
vision_analyze_tool_python: {
name: "vision_analyze_tool_python",
type: ModuleRegistry.ModuleType.TOOL,
label: "Vision Analyze Tool",
},
weatherapi_tool_python: {
name: "weatherapi_tool_python",
type: ModuleRegistry.ModuleType.TOOL,
label: "WeatherAPI Tool",
},
bingsearch_tool_python: {
name: "bingsearch_tool_python",
type: ModuleRegistry.ModuleType.TOOL,
label: "BingSearch Tool",
},
vision_tool_python: {
name: "vision_tool_python",
type: ModuleRegistry.ModuleType.TOOL,
label: "Vision Tool",
},
}

export const moduleRegistry: Record<string, ModuleRegistry.Module> = {
...sttModuleRegistry,
...llmModuleRegistry,
...ttsModuleRegistry,
...v2vModuleRegistry
}

export const compatibleTools: Record<string, string[]> = {
openai_chatgpt_python: ["vision_tool_python", "weatherapi_tool_python", "bingsearch_tool_python"],
openai_v2v_python: ["weatherapi_tool_python", "bingsearch_tool_python"],
}
Loading

0 comments on commit 3eb90b5

Please sign in to comment.