Skip to content

Commit

Permalink
feat: change module type to enum
Browse files Browse the repository at this point in the history
  • Loading branch information
plutoless committed Dec 9, 2024
1 parent c2352b2 commit cda3b08
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 36 deletions.
4 changes: 2 additions & 2 deletions playground/src/common/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ const useGraphs = () => {
if (!selectedGraph) {
return null
}
const node = selectedGraph.nodes.find((node) => node.name === nodeName)
const node = selectedGraph.nodes.find((node: Node) => node.name === nodeName)
if (!node) {
return null
}
Expand All @@ -172,7 +172,7 @@ const useGraphs = () => {


const getInstalledAndRegisteredModulesMap = useCallback(() => {
const groupedModules: Record<ModuleRegistry.ModuleType, ModuleRegistry.Module[]> = {
const groupedModules: Record<ModuleRegistry.NonToolModuleType, ModuleRegistry.Module[]> = {
stt: [],
tts: [],
llm: [],
Expand Down
69 changes: 38 additions & 31 deletions playground/src/common/moduleConfig.ts
Original file line number Diff line number Diff line change
@@ -1,121 +1,128 @@
export namespace ModuleRegistry {
export type ModuleType = "stt" | "llm" | "v2v" | "tts";
export type ToolModuleType = "tool";
export enum ModuleType {
STT = "stt",
LLM = "llm",
V2V = "v2v",
TTS = "tts",
TOOL = "tool",
}

export interface Module {
name: string;
type: ModuleType | ToolModuleType;
type: ModuleType;
label: string;
}

export type ToolModule = ModuleRegistry.Module & { type: "tool" };
export type NonToolModuleType = Exclude<ModuleType, ModuleType.TOOL>
export type NonToolModule = Module & { type: NonToolModuleType };
export type ToolModule = Module & { type: ModuleType.TOOL };
}


// Custom labels for specific keys
export const ModuleTypeLabels: Record<ModuleRegistry.ModuleType, string> = {
stt: "STT (Speech to Text)",
llm: "LLM (Large Language Model)",
tts: "TTS (Text to Speech)",
v2v: "LLM v2v (V2V Large Language Model)",
export const ModuleTypeLabels: Record<ModuleRegistry.NonToolModuleType, string> = {
[ModuleRegistry.ModuleType.STT]: "STT (Speech to Text)",
[ModuleRegistry.ModuleType.LLM]: "LLM (Large Language Model)",
[ModuleRegistry.ModuleType.TTS]: "TTS (Text to Speech)",
[ModuleRegistry.ModuleType.V2V]: "LLM v2v (V2V Large Language Model)",
};

export const sttModuleRegistry:Record<string, ModuleRegistry.Module> = {
export const sttModuleRegistry: Record<string, ModuleRegistry.Module> = {
deepgram_asr_python: {
name: "deepgram_asr_python",
type: "stt",
type: ModuleRegistry.ModuleType.STT,
label: "Deepgram STT",
},
transcribe_asr_python: {
name: "transcribe_asr_python",
type: "stt",
type: ModuleRegistry.ModuleType.STT,
label: "Transcribe STT",
}
}

export const llmModuleRegistry:Record<string, ModuleRegistry.Module> = {
export const llmModuleRegistry: Record<string, ModuleRegistry.Module> = {
openai_chatgpt_python: {
name: "openai_chatgpt_python",
type: "llm",
type: ModuleRegistry.ModuleType.LLM,
label: "OpenAI ChatGPT",
},
gemini_llm_python: {
name: "gemini_llm_python",
type: "llm",
type: ModuleRegistry.ModuleType.LLM,
label: "Gemini LLM",
},
bedrock_llm_python: {
name: "bedrock_llm_python",
type: "llm",
type: ModuleRegistry.ModuleType.LLM,
label: "Bedrock LLM",
},
}

export const ttsModuleRegistry:Record<string, ModuleRegistry.Module> = {
export const ttsModuleRegistry: Record<string, ModuleRegistry.Module> = {
azure_tts: {
name: "azure_tts",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Azure TTS",
},
cartesia_tts: {
name: "cartesia_tts",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Cartesia TTS",
},
cosy_tts_python: {
name: "cosy_tts_python",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Cosy TTS",
},
elevenlabs_tts_python: {
name: "elevenlabs_tts_python",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Elevenlabs TTS",
},
fish_audio_tts: {
name: "fish_audio_tts",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Fish Audio TTS",
},
minimax_tts_python: {
name: "minimax_tts_python",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Minimax TTS",
},
polly_tts: {
name: "polly_tts",
type: "tts",
type: ModuleRegistry.ModuleType.TTS,
label: "Polly TTS",
}
}

export const v2vModuleRegistry:Record<string, ModuleRegistry.Module> = {
export const v2vModuleRegistry: Record<string, ModuleRegistry.Module> = {
openai_v2v_python: {
name: "openai_v2v_python",
type: "v2v",
type: ModuleRegistry.ModuleType.V2V,
label: "OpenAI Realtime",
}
}

export const toolModuleRegistry:Record<string, ModuleRegistry.ToolModule> = {
export const toolModuleRegistry: Record<string, ModuleRegistry.ToolModule> = {
vision_analyze_tool_python: {
name: "vision_analyze_tool_python",
type: "tool",
type: ModuleRegistry.ModuleType.TOOL,
label: "Vision Analyze Tool",
},
weatherapi_tool_python: {
name: "weatherapi_tool_python",
type: "tool",
type: ModuleRegistry.ModuleType.TOOL,
label: "WeatherAPI Tool",
},
bingsearch_tool_python: {
name: "bingsearch_tool_python",
type: "tool",
type: ModuleRegistry.ModuleType.TOOL,
label: "BingSearch Tool",
},
vision_tool_python: {
name: "vision_tool_python",
type: "tool",
type: ModuleRegistry.ModuleType.TOOL,
label: "Vision Tool",
},
}
Expand Down
11 changes: 8 additions & 3 deletions playground/src/components/Chat/ChatCfgModuleSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export function RemoteModuleCfgSheet() {

if (selectedGraph) {
Object.keys(installedAndRegisteredModulesMap).forEach((key) => {
const moduleTypeKey = key as ModuleRegistry.ModuleType;
const moduleTypeKey = key as ModuleRegistry.NonToolModuleType;

// Check if the current graph has a node whose name contains the ModuleType
const hasMatchingNode = selectedGraph.nodes.some((node) =>
Expand Down Expand Up @@ -252,7 +252,12 @@ const GraphModuleCfgForm = ({
}, [toolModules, selectedGraph]);

// Desired field order
const fieldOrder: ModuleRegistry.ModuleType[] = ["stt", "llm", "v2v", "tts"];
const fieldOrder: ModuleRegistry.NonToolModuleType[] = [
ModuleRegistry.ModuleType.STT,
ModuleRegistry.ModuleType.LLM,
ModuleRegistry.ModuleType.V2V,
ModuleRegistry.ModuleType.TTS,
];
return (
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-6">
Expand All @@ -266,7 +271,7 @@ const GraphModuleCfgForm = ({
render={({ field }) => (
<FormItem>
<FormLabel>
<div className="flex items-center justify-center ">
<div className="flex items-center justify-between ">
<div className="py-3">{ModuleTypeLabels[key]}</div>
{isLLM(key) && (
<DropdownMenu>
Expand Down

0 comments on commit cda3b08

Please sign in to comment.