From c91d40e9b95ff270308080a89d9133761cb4dce3 Mon Sep 17 00:00:00 2001 From: Ethan Zhang Date: Thu, 17 Oct 2024 17:46:33 -0700 Subject: [PATCH] feat: support customizing prompts & greeting (#343) * feat: support customizing prompts & greeting * remove unnecessary package * fix react build issue * fix: support mobile & ignore empty values --- demo/package.json | 2 +- demo/src/app/api/agents/start/graph.tsx | 20 ++++- demo/src/app/api/agents/start/route.tsx | 6 +- demo/src/app/home/page.tsx | 6 +- demo/src/common/constant.ts | 7 ++ demo/src/common/request.ts | 8 +- demo/src/common/storage.ts | 25 ++++-- demo/src/components/authInitializer/index.tsx | 9 +- demo/src/components/loginCard/index.tsx | 12 +-- demo/src/components/settings/index.tsx | 90 +++++++++++++++++++ .../src/platform/mobile/description/index.tsx | 5 +- demo/src/platform/mobile/entry/index.tsx | 2 + .../platform/pc/description/index.module.scss | 7 +- demo/src/platform/pc/description/index.tsx | 7 +- demo/src/platform/pc/entry/index.tsx | 4 + demo/src/store/reducers/global.ts | 16 ++-- demo/src/types/index.ts | 5 ++ 17 files changed, 192 insertions(+), 39 deletions(-) create mode 100644 demo/src/components/settings/index.tsx diff --git a/demo/package.json b/demo/package.json index 1a8d6f4f..4f3b2ac6 100644 --- a/demo/package.json +++ b/demo/package.json @@ -13,7 +13,7 @@ "@ant-design/icons": "^5.3.7", "@reduxjs/toolkit": "^2.2.3", "agora-rtc-sdk-ng": "^4.21.0", - "antd": "^5.15.3", + "antd": "^5.21.4", "axios": "^1.7.7", "next": "14.2.4", "protobufjs": "^7.2.5", diff --git a/demo/src/app/api/agents/start/graph.tsx b/demo/src/app/api/agents/start/graph.tsx index 17428e96..f01d6d95 100644 --- a/demo/src/app/api/agents/start/graph.tsx +++ b/demo/src/app/api/agents/start/graph.tsx @@ -61,7 +61,13 @@ export const voiceNameMap: LanguageMap = { // Get the graph properties based on the graph name, language, and voice type // This is the place where you can customize the properties for different graphs to override default property.json -export const getGraphProperties = (graphName: string, language: string, voiceType: string) => { +export const getGraphProperties = ( + graphName: string, + language: string, + voiceType: string, + prompt: string | undefined, + greeting: string | undefined +) => { let localizationOptions = { "greeting": "Hey, I\'m TEN Agent, I can speak, see, and reason from a knowledge base, ask me anything!", "checking_vision_text_items": "[\"Let me take a look...\",\"Let me check your camera...\",\"Please wait for a second...\"]", @@ -91,7 +97,9 @@ export const getGraphProperties = (graphName: string, language: string, voiceTyp }, "openai_chatgpt": { "model": "gpt-4o", - ...localizationOptions + ...localizationOptions, + "prompt": prompt, + "greeting": greeting, }, "azure_tts": { "azure_synthesis_voice_name": voiceNameMap[language]["azure"][voiceType] @@ -103,7 +111,9 @@ export const getGraphProperties = (graphName: string, language: string, voiceTyp "model": "gpt-4o-realtime-preview", "voice": voiceNameMap[language]["openai"][voiceType], "language": language, - ...localizationOptions + ...localizationOptions, + "system_message": prompt, + "greeting": greeting, } } } else if (graphName == "va.openai.azure") { @@ -113,7 +123,9 @@ export const getGraphProperties = (graphName: string, language: string, voiceTyp }, "openai_chatgpt": { "model": "gpt-4o-mini", - ...localizationOptions + ...localizationOptions, + "prompt": prompt, + "greeting": greeting, }, "azure_tts": { "azure_synthesis_voice_name": voiceNameMap[language]["azure"][voiceType] diff --git a/demo/src/app/api/agents/start/route.tsx b/demo/src/app/api/agents/start/route.tsx index 187621a6..ab98ceab 100644 --- a/demo/src/app/api/agents/start/route.tsx +++ b/demo/src/app/api/agents/start/route.tsx @@ -24,6 +24,8 @@ export async function POST(request: NextRequest) { graph_name, language, voice_type, + prompt, + greeting, } = body; console.log(`Starting agent for request ID: ${JSON.stringify({ @@ -32,7 +34,7 @@ export async function POST(request: NextRequest) { user_uid, graph_name, // Get the graph properties based on the graph name, language, and voice type - properties: getGraphProperties(graph_name, language, voice_type), + properties: getGraphProperties(graph_name, language, voice_type, prompt, greeting), })}`); console.log(`AGENT_SERVER_URL: ${AGENT_SERVER_URL}/start`); @@ -44,7 +46,7 @@ export async function POST(request: NextRequest) { user_uid, graph_name, // Get the graph properties based on the graph name, language, and voice type - properties: getGraphProperties(graph_name, language, voice_type), + properties: getGraphProperties(graph_name, language, voice_type, prompt, greeting), }); const responseData = response.data; diff --git a/demo/src/app/home/page.tsx b/demo/src/app/home/page.tsx index e0549df0..c910177a 100644 --- a/demo/src/app/home/page.tsx +++ b/demo/src/app/home/page.tsx @@ -26,9 +26,9 @@ export default function Home() { return ( mobile === null ? <> : - - {mobile ? : } - + + {mobile ? : } + ); } diff --git a/demo/src/common/constant.ts b/demo/src/common/constant.ts index a483df00..00fe177d 100644 --- a/demo/src/common/constant.ts +++ b/demo/src/common/constant.ts @@ -1,11 +1,18 @@ import { IOptions, ColorItem, LanguageOptionItem, VoiceOptionItem, GraphOptionItem } from "@/types" export const GITHUB_URL = "https://github.com/TEN-framework/TEN-Agent" export const OPTIONS_KEY = "__options__" +export const AGENT_SETTINGS_KEY = "__agent_settings__" export const DEFAULT_OPTIONS: IOptions = { channel: "", userName: "", userId: 0 } + +export const DEFAULT_AGENT_SETTINGS = { + greeting: "", + prompt: "" +} + export const DESCRIPTION = "The World's First Multimodal AI Agent with the OpenAI Realtime API (Beta)" export const LANGUAGE_OPTIONS: LanguageOptionItem[] = [ { diff --git a/demo/src/common/request.ts b/demo/src/common/request.ts index 0c65e352..6f6683b9 100644 --- a/demo/src/common/request.ts +++ b/demo/src/common/request.ts @@ -8,6 +8,8 @@ interface StartRequestConfig { graphName: string, language: Language, voiceType: "male" | "female" + prompt?: string, + greeting?: string, } interface GenAgoraDataConfig { @@ -32,14 +34,16 @@ export const apiGenAgoraData = async (config: GenAgoraDataConfig) => { export const apiStartService = async (config: StartRequestConfig): Promise => { // look at app/api/agents/start/route.tsx for the server-side implementation const url = `/api/agents/start` - const { channel, userId, graphName, language, voiceType } = config + const { channel, userId, graphName, language, voiceType, greeting, prompt } = config const data = { request_id: genUUID(), channel_name: channel, user_uid: userId, graph_name: graphName, language, - voice_type: voiceType + voice_type: voiceType, + greeting: greeting ? greeting : undefined, + prompt: prompt ? prompt : undefined } let resp: any = await axios.post(url, data) resp = (resp.data) || {} diff --git a/demo/src/common/storage.ts b/demo/src/common/storage.ts index ed96083d..29bc673c 100644 --- a/demo/src/common/storage.ts +++ b/demo/src/common/storage.ts @@ -1,14 +1,19 @@ -import { IOptions } from "@/types" -import { OPTIONS_KEY, DEFAULT_OPTIONS } from "./constant" +import { IAgentSettings, IOptions } from "@/types" +import { OPTIONS_KEY, DEFAULT_OPTIONS, AGENT_SETTINGS_KEY, DEFAULT_AGENT_SETTINGS } from "./constant" -export const getOptionsFromLocal = () => { +export const getOptionsFromLocal = (): {options:IOptions, settings: IAgentSettings} => { + let data = {options: DEFAULT_OPTIONS, settings: DEFAULT_AGENT_SETTINGS} if (typeof window !== "undefined") { - const data = localStorage.getItem(OPTIONS_KEY) - if (data) { - return JSON.parse(data) + const options = localStorage.getItem(OPTIONS_KEY) + if (options) { + data.options = JSON.parse(options) + } + const settings = localStorage.getItem(AGENT_SETTINGS_KEY) + if (settings) { + data.settings = JSON.parse(settings) } } - return DEFAULT_OPTIONS + return data } @@ -18,4 +23,8 @@ export const setOptionsToLocal = (options: IOptions) => { } } - +export const setAgentSettingsToLocal = (settings: IAgentSettings) => { + if (typeof window !== "undefined") { + localStorage.setItem(AGENT_SETTINGS_KEY, JSON.stringify(settings)) + } +} diff --git a/demo/src/components/authInitializer/index.tsx b/demo/src/components/authInitializer/index.tsx index 5ef763a1..87ead71a 100644 --- a/demo/src/components/authInitializer/index.tsx +++ b/demo/src/components/authInitializer/index.tsx @@ -2,7 +2,7 @@ import { ReactNode, useEffect } from "react" import { useAppDispatch, getOptionsFromLocal } from "@/common" -import { setOptions, reset } from "@/store/reducers/global" +import { setOptions, reset, setAgentSettings } from "@/store/reducers/global" interface AuthInitializerProps { children: ReactNode; @@ -14,10 +14,11 @@ const AuthInitializer = (props: AuthInitializerProps) => { useEffect(() => { if (typeof window !== "undefined") { - const options = getOptionsFromLocal() - if (options) { + const data = getOptionsFromLocal() + if (data) { dispatch(reset()) - dispatch(setOptions(options)) + dispatch(setOptions(data.options)) + dispatch(setAgentSettings(data.settings)) } } }, [dispatch]) diff --git a/demo/src/components/loginCard/index.tsx b/demo/src/components/loginCard/index.tsx index 457511a5..ae5f2aa9 100644 --- a/demo/src/components/loginCard/index.tsx +++ b/demo/src/components/loginCard/index.tsx @@ -1,11 +1,11 @@ "use client" -import type React from 'react'; +import type React from 'react'; import { useRouter } from 'next/navigation' import { message } from "antd" import { useState, useEffect } from "react" import { GITHUB_URL, getRandomUserId, useAppDispatch, getRandomChannel } from "@/common" -import { setOptions } from "@/store/reducers/global" +import { setAgentSettings, setOptions } from "@/store/reducers/global" import styles from "./index.module.scss" import { GithubIcon } from "../icons" @@ -60,8 +60,8 @@ const LoginCard = () => { return
- { if (e.key === 'Enter' || e.key === ' ') { @@ -84,8 +84,8 @@ const LoginCard = () => {
-
{ if (e.key === 'Enter' || e.key === ' ') { diff --git a/demo/src/components/settings/index.tsx b/demo/src/components/settings/index.tsx new file mode 100644 index 00000000..3c3e66b9 --- /dev/null +++ b/demo/src/components/settings/index.tsx @@ -0,0 +1,90 @@ +import React, { useState } from 'react'; +import { Modal, Form, Input, Button, FloatButton, ConfigProvider, theme } from 'antd'; +import { SettingOutlined } from '@ant-design/icons'; +import { useAppDispatch, useAppSelector } from '@/common'; +import { setAgentSettings } from '@/store/reducers/global'; + +interface FormValues { + greeting: string; + prompt: string; +} + +const FormModal: React.FC = () => { + const [isModalVisible, setIsModalVisible] = useState(false); + const [form] = Form.useForm(); + const dispatch = useAppDispatch(); + const agentSettings = useAppSelector(state => state.global.agentSettings); + + const showModal = () => { + form.setFieldsValue(agentSettings); + setIsModalVisible(true); + }; + + const handleOk = async () => { + try { + const values = await form.validateFields(); + console.log('Form Values:', values); + // Handle the form submission logic here + dispatch(setAgentSettings(values)); + setIsModalVisible(false); + form.resetFields(); + } catch (errorInfo) { + console.log('Validate Failed:', errorInfo); + } + }; + + const handleCancel = () => { + setIsModalVisible(false); + }; + + return ( + <> + + } onClick={showModal}> + + +
+ + + + + + + +
+
+
+ + ); +}; + +export default FormModal; diff --git a/demo/src/platform/mobile/description/index.tsx b/demo/src/platform/mobile/description/index.tsx index 7473d550..0241a768 100644 --- a/demo/src/platform/mobile/description/index.tsx +++ b/demo/src/platform/mobile/description/index.tsx @@ -18,6 +18,7 @@ const Description = () => { const language = useAppSelector(state => state.global.language) const voiceType = useAppSelector(state => state.global.voiceType) const graphName = useAppSelector(state => state.global.graphName) + const agentSettings = useAppSelector(state => state.global.agentSettings) const [loading, setLoading] = useState(false) useEffect(() => { @@ -50,7 +51,9 @@ const Description = () => { userId, graphName, language, - voiceType + voiceType, + greeting: agentSettings.greeting, + prompt: agentSettings.prompt }) const { code, msg } = res || {} if (code != 0) { diff --git a/demo/src/platform/mobile/entry/index.tsx b/demo/src/platform/mobile/entry/index.tsx index c5f51d5c..3b19cca7 100644 --- a/demo/src/platform/mobile/entry/index.tsx +++ b/demo/src/platform/mobile/entry/index.tsx @@ -4,6 +4,7 @@ import Rtc from "../rtc" import Header from "../header" import Menu, { IMenuData } from "../menu" import styles from "./index.module.scss" +import FormModal from "@/components/settings" const MenuData: IMenuData[] = [{ @@ -23,6 +24,7 @@ const MobileEntry = () => {
+
} diff --git a/demo/src/platform/pc/description/index.module.scss b/demo/src/platform/pc/description/index.module.scss index 50b29301..acefae3a 100644 --- a/demo/src/platform/pc/description/index.module.scss +++ b/demo/src/platform/pc/description/index.module.scss @@ -45,6 +45,11 @@ caret-color: transparent; box-sizing: border-box; + &.btnSetting { + background: #181A1D; + border: 1px solid #272A2F; + } + .btnText { width: 100px; text-align: center; @@ -70,4 +75,4 @@ border: 1px solid var(--Error-400-T, #E95C7B); } -} +} \ No newline at end of file diff --git a/demo/src/platform/pc/description/index.tsx b/demo/src/platform/pc/description/index.tsx index b95f4658..6d9ec47b 100644 --- a/demo/src/platform/pc/description/index.tsx +++ b/demo/src/platform/pc/description/index.tsx @@ -5,7 +5,7 @@ import { } from "@/common" import { Select, Button, message, Upload } from "antd" import { useEffect, useState, MouseEventHandler } from "react" -import { LoadingOutlined, UploadOutlined } from "@ant-design/icons" +import { LoadingOutlined, SettingFilled } from "@ant-design/icons" import styles from "./index.module.scss" let intervalId: any @@ -18,6 +18,7 @@ const Description = () => { const language = useAppSelector(state => state.global.language) const voiceType = useAppSelector(state => state.global.voiceType) const graphName = useAppSelector(state => state.global.graphName) + const agentSettings = useAppSelector(state => state.global.agentSettings) const [loading, setLoading] = useState(false) useEffect(() => { @@ -50,7 +51,9 @@ const Description = () => { userId, graphName, language, - voiceType + voiceType, + greeting: agentSettings.greeting, + prompt: agentSettings.prompt }) const { code, msg } = res || {} if (code != 0) { diff --git a/demo/src/platform/pc/entry/index.tsx b/demo/src/platform/pc/entry/index.tsx index a7ee7592..ad5a7875 100644 --- a/demo/src/platform/pc/entry/index.tsx +++ b/demo/src/platform/pc/entry/index.tsx @@ -4,6 +4,9 @@ import Rtc from "../rtc" import Header from "../header" import styles from "./index.module.scss" +import { FloatButton, Form } from "antd" +import { SettingOutlined } from "@ant-design/icons" +import FormModal from "@/components/settings" const PCEntry = () => { return
@@ -19,6 +22,7 @@ const PCEntry = () => {
+ } diff --git a/demo/src/store/reducers/global.ts b/demo/src/store/reducers/global.ts index 170f6e86..25a5e29b 100644 --- a/demo/src/store/reducers/global.ts +++ b/demo/src/store/reducers/global.ts @@ -1,6 +1,6 @@ -import { IOptions, IChatItem, Language, VoiceType } from "@/types" +import { IOptions, IChatItem, Language, VoiceType, IAgentSettings } from "@/types" import { createSlice, PayloadAction } from "@reduxjs/toolkit" -import { DEFAULT_OPTIONS, COLOR_LIST, setOptionsToLocal, genRandomChatList } from "@/common" +import { DEFAULT_OPTIONS, COLOR_LIST, setOptionsToLocal, genRandomChatList, DEFAULT_AGENT_SETTINGS, setAgentSettingsToLocal } from "@/common" export interface InitialState { options: IOptions @@ -10,7 +10,8 @@ export interface InitialState { language: Language voiceType: VoiceType chatItems: IChatItem[], - graphName: string + graphName: string, + agentSettings: IAgentSettings } const getInitialState = (): InitialState => { @@ -22,7 +23,8 @@ const getInitialState = (): InitialState => { language: "en-US", voiceType: "male", chatItems: [], - graphName: "va.openai.v2v" + graphName: "va.openai.v2v", + agentSettings: DEFAULT_AGENT_SETTINGS, } } @@ -87,6 +89,10 @@ export const globalSlice = createSlice({ setGraphName: (state, action: PayloadAction) => { state.graphName = action.payload }, + setAgentSettings: (state: { agentSettings: any }, action: PayloadAction>) => { + state.agentSettings = { ...state.agentSettings, ...action.payload } + setAgentSettingsToLocal(state.agentSettings) + }, setVoiceType: (state, action: PayloadAction) => { state.voiceType = action.payload }, @@ -99,7 +105,7 @@ export const globalSlice = createSlice({ export const { reset, setOptions, setRoomConnected, setAgentConnected, setVoiceType, - addChatItem, setThemeColor, setLanguage, setGraphName } = + addChatItem, setThemeColor, setLanguage, setGraphName, setAgentSettings } = globalSlice.actions export default globalSlice.reducer diff --git a/demo/src/types/index.ts b/demo/src/types/index.ts index f5492003..7602b7d6 100644 --- a/demo/src/types/index.ts +++ b/demo/src/types/index.ts @@ -13,6 +13,11 @@ export interface IOptions { userId: number } +export interface IAgentSettings { + prompt: string, + greeting: string +} + export interface IChatItem { userId: number | string,