diff --git a/playground/src/common/graph.ts b/playground/src/app/api/agents/start/graph.tsx similarity index 90% rename from playground/src/common/graph.ts rename to playground/src/app/api/agents/start/graph.tsx index 523ee14b5..76a4c9a04 100644 --- a/playground/src/common/graph.ts +++ b/playground/src/app/api/agents/start/graph.tsx @@ -1,4 +1,4 @@ -import { LanguageMap } from "./constant"; +import { LanguageMap } from "@/common/constant"; export const voiceNameMap: LanguageMap = { "zh-CN": { @@ -43,6 +43,8 @@ export const voiceNameMap: LanguageMap = { }, }; +// Get the graph properties based on the graph name, language, and voice type +// This is the place where you can customize the properties for different graphs to override default property.json export const getGraphProperties = (graphName: string, language: string, voiceType: string) => { let localizationOptions = { "greeting": "ASTRA agent connected. How can i help you today?", diff --git a/playground/src/app/api/agents/start/route.tsx b/playground/src/app/api/agents/start/route.tsx new file mode 100644 index 000000000..c83f0e901 --- /dev/null +++ b/playground/src/app/api/agents/start/route.tsx @@ -0,0 +1,50 @@ +import { REQUEST_URL } from '@/common/constant'; +import { NextRequest, NextResponse } from 'next/server'; +import { getGraphProperties } from './graph'; + +/** + * Handles the POST request to start an agent by forwarding it to the `/start` endpoint under REQUEST_URL. + * + * @param request - The NextRequest whose JSON body supplies request_id, channel_name, user_uid, graph_name, language and voice_type. + * @returns The upstream JSON body relayed with the upstream status; if a Response is thrown, its JSON and status; otherwise a 500 JSON error payload. 
+ */ +export async function POST(request: NextRequest) { + try { + const body = await request.json(); + const { + request_id, + channel_name, + user_uid, + graph_name, + language, + voice_type, + } = body; + + // Forward the start request to the agent server + const response = await fetch(`${REQUEST_URL}/start`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + request_id, + channel_name, + user_uid, + graph_name, + // Resolve the property overrides here from graph name, language, and voice type + properties: getGraphProperties(graph_name, language, voice_type), + }), + }); + + const responseData = await response.json(); + + return NextResponse.json(responseData, { status: response.status }); + } catch (error) { + if (error instanceof Response) { + const errorData = await error.json(); + return NextResponse.json(errorData, { status: error.status }); + } else { + return NextResponse.json({ code: "1", data: null, msg: "Internal Server Error" }, { status: 500 }); + } + } +} \ No newline at end of file diff --git a/playground/src/app/api/agents/stop/route.tsx b/playground/src/app/api/agents/stop/route.tsx new file mode 100644 index 000000000..c6c8bcc06 --- /dev/null +++ b/playground/src/app/api/agents/stop/route.tsx @@ -0,0 +1,42 @@ +import { REQUEST_URL } from '@/common/constant'; +import { NextRequest, NextResponse } from 'next/server'; + +/** + * Handles the POST request to stop an agent by forwarding it to the `/stop` endpoint under REQUEST_URL. + * + * @param request - The NextRequest whose JSON body supplies channel_name and request_id. + * @returns The upstream JSON body relayed with the upstream status; if a Response is thrown, its JSON and status; otherwise a 500 JSON error payload. 
+ */ +export async function POST(request: NextRequest) { + try { + const body = await request.json(); + const { + channel_name, + request_id, + } = body; + + // Forward the stop request to the agent server + const response = await fetch(`${REQUEST_URL}/stop`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + request_id, + channel_name + }), + }); + + // Parse the upstream reply as JSON so it can be relayed with the same status + const responseData = await response.json(); + + return NextResponse.json(responseData, { status: response.status }); + } catch (error) { + if (error instanceof Response) { + const errorData = await error.json(); + return NextResponse.json(errorData, { status: error.status }); + } else { + return NextResponse.json({ code: "1", data: null, msg: "Internal Server Error" }, { status: 500 }); + } + } +} \ No newline at end of file diff --git a/playground/src/common/index.ts b/playground/src/common/index.ts index 48ffa71fc..3c2b0300e 100644 --- a/playground/src/common/index.ts +++ b/playground/src/common/index.ts @@ -4,4 +4,3 @@ export * from "./utils" export * from "./storage" export * from "./request" export * from "./mock" -export * from "./graph" \ No newline at end of file diff --git a/playground/src/common/request.ts b/playground/src/common/request.ts index 10803c6af..3b7592550 100644 --- a/playground/src/common/request.ts +++ b/playground/src/common/request.ts @@ -1,12 +1,14 @@ import { AnyObject } from "antd/es/_util/type" import { REQUEST_URL } from "./constant" import { genUUID } from "./utils" +import { Language } from "@/types" interface StartRequestConfig { channel: string userId: number, - graphName: string - properties: AnyObject + graphName: string, + language: Language, + voiceType: "male" | "female" } interface GenAgoraDataConfig { @@ -34,15 +36,16 @@ export const apiGenAgoraData = async (config: GenAgoraDataConfig) => { } export const apiStartService = async (config: StartRequestConfig): Promise => { - const url = 
`${REQUEST_URL}/start` - const { channel, userId, graphName, properties } = config + // look at app/api/agents/start/route.tsx for the server-side implementation + const url = `/api/agents/start` + const { channel, userId, graphName, language, voiceType } = config const data = { request_id: genUUID(), channel_name: channel, - openai_proxy_url: "", - remote_stream_id: userId, + user_uid: userId, graph_name: graphName, - properties, + language, + voice_type: voiceType } let resp: any = await fetch(url, { method: "POST", @@ -56,7 +59,8 @@ export const apiStartService = async (config: StartRequestConfig): Promise } export const apiStopService = async (channel: string) => { - const url = `${REQUEST_URL}/stop` + // look at app/api/agents/stop/route.tsx for the server-side implementation + const url = `/api/agents/stop` const data = { request_id: genUUID(), channel_name: channel diff --git a/playground/src/platform/mobile/description/index.tsx b/playground/src/platform/mobile/description/index.tsx index dca6870f0..7473d5503 100644 --- a/playground/src/platform/mobile/description/index.tsx +++ b/playground/src/platform/mobile/description/index.tsx @@ -1,8 +1,7 @@ import { setAgentConnected } from "@/store/reducers/global" import { DESCRIPTION, useAppDispatch, useAppSelector, apiPing, genUUID, - apiStartService, apiStopService, - getGraphProperties + apiStartService, apiStopService } from "@/common" import { message } from "antd" import { useEffect, useState } from "react" @@ -50,7 +49,8 @@ const Description = () => { channel, userId, graphName, - properties: getGraphProperties(graphName, language, voiceType) + language, + voiceType }) const { code, msg } = res || {} if (code != 0) { diff --git a/playground/src/platform/pc/description/index.tsx b/playground/src/platform/pc/description/index.tsx index a93cf5cc0..a9a055cd2 100644 --- a/playground/src/platform/pc/description/index.tsx +++ b/playground/src/platform/pc/description/index.tsx @@ -1,8 +1,7 @@ import { 
setAgentConnected } from "@/store/reducers/global" import { DESCRIPTION, useAppDispatch, useAppSelector, apiPing, genUUID, - apiStartService, apiStopService, - getGraphProperties + apiStartService, apiStopService } from "@/common" import { Select, Button, message, Upload } from "antd" import { useEffect, useState, MouseEventHandler } from "react" @@ -50,7 +49,8 @@ const Description = () => { channel, userId, graphName, - properties: getGraphProperties(graphName, language, voiceType) + language, + voiceType }) const { code, msg } = res || {} if (code != 0) { diff --git a/server/README.md b/server/README.md new file mode 100644 index 000000000..9ecf41321 --- /dev/null +++ b/server/README.md @@ -0,0 +1,76 @@ +## Request & Response Examples +The server provides a simple layer for managing agent processes. + +### API Resources + + - [POST /start](#post-start) + - [POST /stop](#post-stop) + - [POST /ping](#post-ping) + + +### POST /start +This API starts an agent with the given graph and property overrides. The started agent joins the specified channel and subscribes to the uid that your browser/device's RTC client uses to join. + +| Param | Description | +| -------- | ------- | +| request_id | any uuid for tracing purpose | +| channel_name | channel name, it needs to be the same with the one your browser/device joins, agent needs to stay with your browser/device in the same channel to communicate | +| user_uid | the uid which your browser/device's rtc use to join, agent needs to know your rtc uid to subscribe your audio | +| bot_uid | optional, the uid bot used to join rtc | +| graph_name | the graph to be used when starting agent, will find in property.json | +| properties | additional properties to override in property.json, the override will not change original property.json, only the one agent used to start | +| timeout | determines how long the agent will remain active without receiving any pings. 
If the timeout is set to `-1`, the agent will not terminate due to inactivity. By default, the timeout is set to 60 seconds, but this can be adjusted using the `WORKER_QUIT_TIMEOUT_SECONDS` variable in your `.env` file. | + +Example: +```bash +curl 'http://localhost:8080/start' \ + -H 'Content-Type: application/json' \ + --data-raw '{ + "request_id": "c1912182-924c-4d15-a8bb-85063343077c", + "channel_name": "test", + "user_uid": 176573, + "graph_name": "camera.va.openai.azure", + "properties": { + "openai_chatgpt": { + "model": "gpt-4o" + } + } + }' +``` + +### POST /stop +This api stops the agent you started + +| Param | Description | +| -------- | ------- | +| request_id | any uuid for tracing purpose | +| channel_name | channel name, the one you used to start the agent | + +Example: +```bash +curl 'http://localhost:8080/stop' \ + -H 'Content-Type: application/json' \ + --data-raw '{ + "request_id": "c1912182-924c-4d15-a8bb-85063343077c", + "channel_name": "test" + }' +``` + + +### POST /ping +This api sends a ping to the server to indicate connection is still alive. This is not needed if you specify `timeout:-1` when starting the agent, otherwise the agent will quit if not receiving ping after timeout in seconds. 
+ +| Param | Description | +| -------- | ------- | +| request_id | any uuid for tracing purpose | +| channel_name | channel name, the one you used to start the agent | + +Example: +```bash +curl 'http://localhost:8080/ping' \ + -H 'Content-Type: application/json' \ + --data-raw '{ + "request_id": "c1912182-924c-4d15-a8bb-85063343077c", + "channel_name": "test" + }' +``` diff --git a/server/internal/config.go b/server/internal/config.go index e0e2edb07..02db66f5e 100644 --- a/server/internal/config.go +++ b/server/internal/config.go @@ -18,6 +18,8 @@ const ( PropertyJsonFile = "./agents/property.json" // Token expire time tokenExpirationInSeconds = uint32(86400) + + WORKER_TIMEOUT_INFINITY = -1 ) var ( @@ -31,6 +33,9 @@ var ( "RemoteStreamId": { {ExtensionName: extensionNameAgoraRTC, Property: "remote_stream_id"}, }, + "BotStreamId": { + {ExtensionName: extensionNameAgoraRTC, Property: "uid"}, + }, "Token": { {ExtensionName: extensionNameAgoraRTC, Property: "token"}, }, diff --git a/server/internal/http_server.go b/server/internal/http_server.go index 4ba9a6608..f3fea8b21 100644 --- a/server/internal/http_server.go +++ b/server/internal/http_server.go @@ -51,10 +51,12 @@ type StartReq struct { RequestId string `json:"request_id,omitempty"` ChannelName string `json:"channel_name,omitempty"` GraphName string `json:"graph_name,omitempty"` - RemoteStreamId uint32 `json:"remote_stream_id,omitempty"` + RemoteStreamId uint32 `json:"user_uid,omitempty"` + BotStreamId uint32 `json:"bot_uid,omitempty"` Token string `json:"token,omitempty"` WorkerHttpServerPort int32 `json:"worker_http_server_port,omitempty"` Properties map[string]map[string]interface{} `json:"properties,omitempty"` + QuitTimeoutSeconds int `json:"timeout,omitempty"` } type StopReq struct { @@ -92,6 +94,22 @@ func (s *HttpServer) handlerHealth(c *gin.Context) { s.output(c, codeOk, nil) } +func (s *HttpServer) handlerList(c *gin.Context) { + slog.Info("handlerList start", logTag) + // Preallocate capacity only (length 0) so append does not leave nil entries ahead of 
the real ones in the slice that holds the filtered data + filtered := make([]map[string]interface{}, 0, len(workers.Keys())) + for _, channelName := range workers.Keys() { + worker := workers.Get(channelName).(*Worker) + workerJson := map[string]interface{}{ + "channelName": worker.ChannelName, + "createTs": worker.CreateTs, + } + filtered = append(filtered, workerJson) + } + slog.Info("handlerList end", logTag) + s.output(c, codeSuccess, filtered) +} + func (s *HttpServer) handlerPing(c *gin.Context) { var req PingReq @@ -163,7 +181,13 @@ func (s *HttpServer) handlerStart(c *gin.Context) { worker := newWorker(req.ChannelName, logFile, s.config.Log2Stdout, propertyJsonFile) worker.HttpServerPort = req.WorkerHttpServerPort - worker.QuitTimeoutSeconds = s.config.WorkerQuitTimeoutSeconds + // Honor a per-request timeout, including WORKER_TIMEOUT_INFINITY (-1, never time out); anything else falls back to the server default + if req.QuitTimeoutSeconds > 0 || req.QuitTimeoutSeconds == WORKER_TIMEOUT_INFINITY { + worker.QuitTimeoutSeconds = req.QuitTimeoutSeconds + } else { + worker.QuitTimeoutSeconds = s.config.WorkerQuitTimeoutSeconds + } + if err := worker.start(&req); err != nil { slog.Error("handlerStart start worker failed", "err", err, "requestId", req.RequestId, logTag) s.output(c, codeErrStartWorkerFailed, http.StatusInternalServerError) @@ -461,9 +485,10 @@ func (s *HttpServer) Start() { r.GET("/", s.handlerHealth) r.GET("/health", s.handlerHealth) - r.POST("/ping", s.handlerPing) + r.GET("/list", s.handlerList) r.POST("/start", s.handlerStart) r.POST("/stop", s.handlerStop) + r.POST("/ping", s.handlerPing) r.POST("/token/generate", s.handlerGenerateToken) r.GET("/vector/document/preset/list", s.handlerVectorDocumentPresetList) r.POST("/vector/document/update", s.handlerVectorDocumentUpdate) diff --git a/server/internal/worker.go b/server/internal/worker.go index 7462a68dc..8e4dac9c5 100644 --- a/server/internal/worker.go +++ b/server/internal/worker.go @@ -2,6 +2,7 @@ package internal import ( "bufio" + "bytes" "fmt" "io" "log/slog" @@ -203,6 +204,10 @@ func (w *Worker) start(req *StartReq) (err error) { if logFile != nil { logFile.Close() } + + // Remove the worker from the map + 
workers.Remove(w.ChannelName) + + }() return @@ -253,28 +258,72 @@ func (w *Worker) update(req *WorkerUpdateReq) (err error) { return } +// getRunningWorkerPIDs returns the set of PIDs parsed from `ps aux` lines that match "bin/worker --property"; it returns nil when running the shell pipeline reports an error +func getRunningWorkerPIDs() map[int]struct{} { + // List processes and keep the lines containing the worker launch pattern, excluding the grep process itself + cmd := exec.Command("sh", "-c", `ps aux | grep "bin/worker --property" | grep -v grep`) + + // Run the command and capture its standard output in memory + var out bytes.Buffer + cmd.Stdout = &out + err := cmd.Run() + if err != nil { + return nil + } + + // Parse one PID per output line; lines with fewer than two fields or a non-numeric second field are skipped + lines := strings.Split(out.String(), "\n") + runningPIDs := make(map[int]struct{}) + for _, line := range lines { + fields := strings.Fields(line) + if len(fields) > 1 { + pid, err := strconv.Atoi(fields[1]) // PID is typically the second field + if err == nil { + runningPIDs[pid] = struct{}{} + } + } + } + return runningPIDs +} + +// killProcess sends SIGKILL to pid and logs the outcome; a failure is only logged, not returned +func killProcess(pid int) { + err := syscall.Kill(pid, syscall.SIGKILL) + if err != nil { + slog.Info("Failed to kill process", "pid", pid, "error", err) + } else { + slog.Info("Successfully killed process", "pid", pid) + } +} + func timeoutWorkers() { for { for _, channelName := range workers.Keys() { worker := workers.Get(channelName).(*Worker) + // Skip workers whose timeout is WORKER_TIMEOUT_INFINITY; this loop never stops them + if worker.QuitTimeoutSeconds == WORKER_TIMEOUT_INFINITY { + continue + } + nowTs := time.Now().Unix() if worker.UpdateTs+int64(worker.QuitTimeoutSeconds) < nowTs { if err := worker.stop(uuid.New().String(), channelName.(string)); err != nil { - slog.Error("Worker cleanWorker failed", "err", err, "channelName", channelName, logTag) + slog.Error("Timeout worker stop failed", "err", err, "channelName", channelName, logTag) continue } - slog.Info("Worker cleanWorker success", "channelName", channelName, "worker", worker, "nowTs", nowTs, logTag) + slog.Info("Timeout worker stop success", "channelName", channelName, "worker", worker, "nowTs", nowTs, logTag) } } - slog.Debug("Worker 
cleanWorker sleep", "sleep", workerCleanSleepSeconds, logTag) + slog.Debug("Worker timeout check", "sleep", workerCleanSleepSeconds, logTag) time.Sleep(workerCleanSleepSeconds * time.Second) } } func CleanWorkers() { + // Stop every worker currently registered in the workers map for _, channelName := range workers.Keys() { worker := workers.Get(channelName).(*Worker) if err := worker.stop(uuid.New().String(), channelName.(string)); err != nil { @@ -284,4 +333,22 @@ func CleanWorkers() { slog.Info("Worker cleanWorker success", "channelName", channelName, "worker", worker, logTag) } + + // Collect the PIDs of processes whose command line matches the worker launch pattern + runningPIDs := getRunningWorkerPIDs() + + // Index the workers still present in the workers map by their PID + workerMap := make(map[int]*Worker) + for _, channelName := range workers.Keys() { + worker := workers.Get(channelName).(*Worker) + workerMap[worker.Pid] = worker + } + + // Kill matching processes whose PID does not belong to any worker in the map + for pid := range runningPIDs { + if _, exists := workerMap[pid]; !exists { + slog.Info("Killing redundant process", "pid", pid) + killProcess(pid) + } + } }