diff --git a/apps/desktop/src/lib/trpc/routers/ai-chat/index.ts b/apps/desktop/src/lib/trpc/routers/ai-chat/index.ts index 9557f511fe9..19bd320a2fe 100644 --- a/apps/desktop/src/lib/trpc/routers/ai-chat/index.ts +++ b/apps/desktop/src/lib/trpc/routers/ai-chat/index.ts @@ -88,6 +88,23 @@ export const createAiChatRouter = () => { return { success: true }; }), + updateSessionConfig: publicProcedure + .input( + z.object({ + sessionId: z.string(), + maxThinkingTokens: z.number().nullable().optional(), + model: z.string().nullable().optional(), + }), + ) + .mutation(async ({ input }) => { + await chatSessionManager.updateAgentConfig({ + sessionId: input.sessionId, + maxThinkingTokens: input.maxThinkingTokens, + model: input.model, + }); + return { success: true }; + }), + renameSession: publicProcedure .input( z.object({ diff --git a/apps/desktop/src/lib/trpc/routers/ai-chat/utils/session-manager/session-manager.ts b/apps/desktop/src/lib/trpc/routers/ai-chat/utils/session-manager/session-manager.ts index c3daa54f72c..e5ca6268a6f 100644 --- a/apps/desktop/src/lib/trpc/routers/ai-chat/utils/session-manager/session-manager.ts +++ b/apps/desktop/src/lib/trpc/routers/ai-chat/utils/session-manager/session-manager.ts @@ -309,6 +309,68 @@ export class ChatSessionManager extends EventEmitter { await this.store.update(sessionId, patch); } + async updateAgentConfig({ + sessionId, + maxThinkingTokens, + model, + }: { + sessionId: string; + maxThinkingTokens?: number | null; + model?: string | null; + }): Promise { + const session = this.sessions.get(sessionId); + if (!session) { + console.warn( + `[chat/session] Session ${sessionId} not found for config update`, + ); + return; + } + + const registration = this.provider.getAgentRegistration({ + sessionId, + cwd: session.cwd, + }); + + if (maxThinkingTokens !== undefined) { + if (maxThinkingTokens === null) { + delete registration.bodyTemplate.maxThinkingTokens; + } else { + registration.bodyTemplate.maxThinkingTokens = maxThinkingTokens; + } + } + + if (model !== undefined) { + if (model === null) { + delete registration.bodyTemplate.model; + } else { + registration.bodyTemplate.model = model; + } + } + + const headers = buildProxyHeaders(); + const res = await fetch(`${PROXY_URL}/v1/sessions/${sessionId}/agents`, { + method: "POST", + headers, + body: JSON.stringify({ agents: [registration] }), + }); + if (!res.ok) { + throw new Error( + `POST /v1/sessions/${sessionId}/agents failed: ${res.status}`, + ); + } + + console.log( + `[chat/session] Updated agent config for ${sessionId}`, + [ + maxThinkingTokens !== undefined && + `maxThinkingTokens=${maxThinkingTokens}`, + model !== undefined && `model=${model}`, + ] + .filter(Boolean) + .join(", "), + ); + } + isSessionActive(sessionId: string): boolean { return this.sessions.has(sessionId); } diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatPane/ChatInterface/ChatInterface.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatPane/ChatInterface/ChatInterface.tsx index 861be2b4b76..df99ea86ce3 100644 --- a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatPane/ChatInterface/ChatInterface.tsx +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatPane/ChatInterface/ChatInterface.tsx @@ -16,6 +16,7 @@ import { } from "@superset/ui/ai-elements/prompt-input"; import { Shimmer } from "@superset/ui/ai-elements/shimmer"; import { Suggestion, Suggestions } from "@superset/ui/ai-elements/suggestion"; +import { ThinkingToggle } from "@superset/ui/ai-elements/thinking-toggle"; import { useCallback, useEffect, useRef, useState } from "react"; import { HiMiniAtSymbol, @@ -48,6 +49,9 @@ export function ChatInterface({ }: ChatInterfaceProps) { const [selectedModel, setSelectedModel] = useState(MODELS[1]); const [modelSelectorOpen, setModelSelectorOpen] = useState(false); + const [thinkingEnabled, setThinkingEnabled] = useState(false); + + const updateConfig = electronTrpc.aiChat.updateSessionConfig.useMutation(); const { data: config } = electronTrpc.aiChat.getConfig.useQuery(); @@ -60,6 +64,7 @@ export function ChatInterface({ stop, addToolApprovalResponse, connect, + collections, } = useDurableChat({ sessionId, proxyUrl: config?.proxyUrl ?? "http://localhost:8080", @@ -214,6 +219,28 @@ export function ChatInterface({ [addToolApprovalResponse], ); + const handleThinkingToggle = useCallback( + (enabled: boolean) => { + setThinkingEnabled(enabled); + updateConfig.mutate({ + sessionId, + maxThinkingTokens: enabled ? 10000 : null, + }); + }, + [sessionId, updateConfig], + ); + + const handleModelSelect = useCallback( + (model: ModelOption) => { + setSelectedModel(model); + updateConfig.mutate({ + sessionId, + model: model.id, + }); + }, + [sessionId, updateConfig], + ); + const handleStop = useCallback( (e: React.MouseEvent) => { e.preventDefault(); @@ -296,15 +323,22 @@ export function ChatInterface({ +
- + + q.from({ s: collections.sessionStats }).select(({ s }) => ({ ...s })), + ); + + const stats = statsRows?.[0]; + const usedTokens = stats?.totalTokens ?? 0; + const usage = { + inputTokens: stats?.promptTokens ?? 0, + outputTokens: stats?.completionTokens ?? 0, + totalTokens: stats?.totalTokens ?? 0, + }; -export function ContextIndicator() { return ( @@ -39,8 +48,6 @@ export function ContextIndicator() {
- -
diff --git a/apps/streams/src/claude-agent.ts b/apps/streams/src/claude-agent.ts index 9f5ecec9cfd..3ba30082df9 100644 --- a/apps/streams/src/claude-agent.ts +++ b/apps/streams/src/claude-agent.ts @@ -32,6 +32,8 @@ const agentRequestSchema = z.object({ cwd: z.string().optional(), env: z.record(z.string(), z.string()).optional(), notification: notificationSchema.optional(), + maxThinkingTokens: z.number().optional(), + model: z.string().optional(), }); interface SessionEntry { @@ -169,7 +171,15 @@ app.post("/", async (c) => { ); } - const { messages, sessionId, cwd, env: agentEnv, notification } = parsed.data; + const { + messages, + sessionId, + cwd, + env: agentEnv, + notification, + maxThinkingTokens, + model, + } = parsed.data; const latestUserMessage = messages?.filter((m) => m.role === "user").pop(); @@ -197,7 +207,7 @@ app.post("/", async (c) => { options: { ...(claudeSessionId && { resume: claudeSessionId }), ...(cwd && { cwd }), - model: process.env.CLAUDE_MODEL ?? DEFAULT_MODEL, + model: model ?? process.env.CLAUDE_MODEL ?? DEFAULT_MODEL, maxTurns: MAX_AGENT_TURNS, includePartialMessages: true, permissionMode: "bypassPermissions" as const, @@ -205,6 +215,7 @@ app.post("/", async (c) => { env: queryEnv, abortController, ...(hooks && { hooks }), + ...(maxThinkingTokens !== undefined && { maxThinkingTokens }), }, }); diff --git a/packages/ui/src/components/ai-elements/bash-tool.tsx b/packages/ui/src/components/ai-elements/bash-tool.tsx index f9ac29a9f04..3d3b2847bb5 100644 --- a/packages/ui/src/components/ai-elements/bash-tool.tsx +++ b/packages/ui/src/components/ai-elements/bash-tool.tsx @@ -82,10 +82,7 @@ export const BashTool = ({ if (state === "input-streaming") { return (
diff --git a/packages/ui/src/components/ai-elements/thinking-toggle.tsx b/packages/ui/src/components/ai-elements/thinking-toggle.tsx new file mode 100644 index 00000000000..56380bcdc71 --- /dev/null +++ b/packages/ui/src/components/ai-elements/thinking-toggle.tsx @@ -0,0 +1,55 @@ +"use client"; + +import { BrainIcon } from "lucide-react"; +import type { ComponentProps } from "react"; +import { cn } from "../../lib/utils"; +import { Button } from "../ui/button"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "../ui/tooltip"; + +export type ThinkingToggleProps = Omit< + ComponentProps, + "onClick" | "onToggle" +> & { + enabled: boolean; + onToggle: (enabled: boolean) => void; +}; + +export const ThinkingToggle = ({ + enabled, + onToggle, + className, + ...props +}: ThinkingToggleProps) => ( + + + + + + +

+ {enabled ? "Extended thinking enabled" : "Enable extended thinking"} +

+
+
+
+); diff --git a/packages/ui/src/components/ai-elements/tool-call.tsx b/packages/ui/src/components/ai-elements/tool-call.tsx index 7c1a4815678..631931c9dd1 100644 --- a/packages/ui/src/components/ai-elements/tool-call.tsx +++ b/packages/ui/src/components/ai-elements/tool-call.tsx @@ -29,10 +29,7 @@ export const ToolCall = ({ return (
diff --git a/packages/ui/src/components/ai-elements/tool-interrupted.tsx b/packages/ui/src/components/ai-elements/tool-interrupted.tsx index 93a34d6f047..f5ea23eaa38 100644 --- a/packages/ui/src/components/ai-elements/tool-interrupted.tsx +++ b/packages/ui/src/components/ai-elements/tool-interrupted.tsx @@ -13,12 +13,7 @@ export const ToolInterrupted = ({ subtitle, className, }: ToolInterruptedProps) => ( -
+
{toolName} interrupted