diff --git a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts index 242d91437f7..5d508540789 100644 --- a/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts +++ b/apps/desktop/src/lib/trpc/routers/workspaces/utils/ai-name.ts @@ -1,6 +1,5 @@ import { createAnthropic } from "@ai-sdk/anthropic"; import { createOpenAI } from "@ai-sdk/openai"; -import type { Agent } from "@mastra/core/agent"; import { generateTitleFromMessage, getCredentialsFromAnySource as getAnthropicCredentialsFromAnySource, @@ -28,7 +27,10 @@ export type WorkspaceAutoRenameResult = warning?: string; }; -type AgentModel = ConstructorParameters[0]["model"]; +type AgentModel = Extract< + Parameters[0], + { agentModel: unknown } +>["agentModel"]; type AnthropicCredentials = NonNullable< ReturnType >; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/ChatMastraInterface.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/ChatMastraInterface.tsx index a76366d36c2..66d44211b47 100644 --- a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/ChatMastraInterface.tsx +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/ChatMastraInterface.tsx @@ -25,20 +25,22 @@ import type { PermissionMode, } from "../../ChatPane/ChatInterface/types"; import { ChatMastraMessageList } from "./components/ChatMastraMessageList"; +import type { UserMessageRestartRequest } from "./components/ChatMastraMessageList/ChatMastraMessageList.types"; import { McpControls } from "./components/McpControls"; import { useMcpUi } from "./hooks/useMcpUi"; import { useOptimisticUpload } from "./hooks/useOptimisticUpload"; import type { ChatMastraInterfaceProps } from "./types"; -import { - hasMatchingUserMessage, - type MastraHistoryMessage, - toOptimisticUserMessage, -} from "./utils/optimisticUserMessage"; +import { toOptimisticUserMessage } from "./utils/optimisticUserMessage"; import { type ChatSendMessageInput, sendMessageForSession, toSendFailureMessage, } from "./utils/sendMessage"; +import { + getVisibleMessagesWithPendingUserTurn, + type PendingUserTurn, + shouldClearPendingUserTurn, +} from "./utils/transientUserTurn"; import { uploadFiles } from "./utils/uploadFiles"; type HarnessFilePayload = { @@ -212,11 +214,14 @@ export function ChatMastraInterface({ const [runtimeError, setRuntimeError] = useState(null); const [interruptedMessage, setInterruptedMessage] = useState(null); - const [pendingImmediateUserMessage, setPendingImmediateUserMessage] = - useState(null); const [approvalResponsePending, setApprovalResponsePending] = useState(false); const [planResponsePending, setPlanResponsePending] = useState(false); const [questionResponsePending, setQuestionResponsePending] = useState(false); + const [editingUserMessageId, setEditingUserMessageId] = useState< + string | null + >(null); + const [pendingUserTurn, setPendingUserTurn] = + useState(null); const currentMcpScopeRef = useRef(null); const consumedLaunchConfigRef = useRef(null); const autoLaunchInFlightRef = useRef(null); @@ -267,6 +272,8 @@ export function ChatMastraInterface({ pendingPlanApproval = null, pendingQuestion = null, } = chat; + const isAwaitingAssistant = + isRunning || submitStatus === "submitted" || submitStatus === "streaming"; const clearRuntimeError = useCallback(() => { setRuntimeError(null); @@ -432,7 +439,8 @@ export function ChatMastraInterface({ setSubmitStatus(undefined); setRuntimeError(null); setInterruptedMessage(null); - setPendingImmediateUserMessage(null); + setPendingUserTurn(null); + setEditingUserMessageId(null); resetMcpUi(); if (sessionId) { void refreshMcpOverview(); @@ -440,31 +448,30 @@ export function ChatMastraInterface({ }, [cwd, refreshMcpOverview, resetMcpUi, sessionId]); useEffect(() => { - if (!pendingImmediateUserMessage) return; if ( - hasMatchingUserMessage({ + shouldClearPendingUserTurn({ messages, - candidate: pendingImmediateUserMessage, + pendingUserTurn, + isAwaitingAssistant, }) ) { - setPendingImmediateUserMessage(null); + setPendingUserTurn(null); } - }, [messages, pendingImmediateUserMessage]); + }, [isAwaitingAssistant, messages, pendingUserTurn]); + + useEffect(() => { + if (!editingUserMessageId) return; + if (messages.some((message) => message.id === editingUserMessageId)) return; + setEditingUserMessageId(null); + }, [editingUserMessageId, messages]); const visibleMessages = useMemo(() => { - if (!pendingImmediateUserMessage) return messages; - if ( - hasMatchingUserMessage({ - messages, - candidate: pendingImmediateUserMessage, - }) - ) { - return messages; - } - return [...messages, pendingImmediateUserMessage]; - }, [messages, pendingImmediateUserMessage]); - const isAwaitingAssistant = - isRunning || submitStatus === "submitted" || submitStatus === "streaming"; + return getVisibleMessagesWithPendingUserTurn({ + messages, + pendingUserTurn, + isAwaitingAssistant, + }); + }, [isAwaitingAssistant, messages, pendingUserTurn]); useEffect(() => { if (isRunning) { @@ -573,7 +580,10 @@ export function ChatMastraInterface({ ? toOptimisticUserMessage(sendInput) : null; if (immediateUserMessage) { - setPendingImmediateUserMessage(immediateUserMessage); + setPendingUserTurn({ + kind: "append", + message: immediateUserMessage, + }); } let targetSessionId = effectiveSessionId; @@ -605,10 +615,11 @@ export function ChatMastraInterface({ setSubmitStatus(undefined); setRuntimeErrorMessage(sendErrorMessage); if (immediateUserMessage) { - setPendingImmediateUserMessage((previousMessage) => - previousMessage?.id === immediateUserMessage.id + setPendingUserTurn((previousTurn) => + previousTurn?.kind === "append" && + previousTurn.message.id === immediateUserMessage.id ? null - : previousMessage, + : previousTurn, ); } if (error instanceof Error) throw error; @@ -780,6 +791,94 @@ export function ChatMastraInterface({ }, [handleSend], ); + const restartFromUserMessage = useCallback( + async ( + request: UserMessageRestartRequest, + options?: { trigger?: "edit" | "resend" }, + ) => { + if (!sessionId) { + throw new Error("Chat session is still starting. Please retry."); + } + + setInterruptedMessage(null); + setPendingUserTurn(null); + setSubmitStatus("submitted"); + clearRuntimeError(); + + const optimisticMessage = toOptimisticUserMessage({ + payload: request.payload, + metadata: { + model: activeModel?.id, + }, + }); + if (optimisticMessage) { + setPendingUserTurn({ + kind: "restart", + prefixMessages: request.prefixMessages, + message: optimisticMessage, + }); + } + + try { + await chatMastraServiceTrpcUtils.client.session.restartFromMessage.mutate( + { + sessionId, + ...(cwd ? { cwd } : {}), + messageId: request.messageId, + payload: request.payload, + metadata: { + model: activeModel?.id, + }, + }, + ); + setEditingUserMessageId(null); + if (request.payload.content) { + onUserMessageSubmitted?.(request.payload.content); + } + captureChatEvent("chat_message_sent", { + session_id: sessionId, + model_id: activeModel?.id ?? null, + mention_count: 0, + attachment_count: request.payload.files?.length ?? 0, + is_slash_command: false, + message_length: request.payload.content.length, + turn_number: (messages?.length ?? 0) + 1, + send_trigger: options?.trigger ?? "resend", + restarted_from_message_id: request.messageId, + }); + } catch (error) { + setPendingUserTurn(null); + const sendErrorMessage = toSendFailureMessage(error); + setSubmitStatus(undefined); + setRuntimeErrorMessage(sendErrorMessage); + if (error instanceof Error) throw error; + throw new Error(sendErrorMessage); + } + }, + [ + activeModel?.id, + captureChatEvent, + chatMastraServiceTrpcUtils.client.session.restartFromMessage, + clearRuntimeError, + cwd, + messages, + onUserMessageSubmitted, + sessionId, + setRuntimeErrorMessage, + ], + ); + const handleResendUserMessage = useCallback( + async (request: UserMessageRestartRequest) => { + await restartFromUserMessage(request, { trigger: "resend" }); + }, + [restartFromUserMessage], + ); + const handleSubmitEditedUserMessage = useCallback( + async (request: UserMessageRestartRequest) => { + await restartFromUserMessage(request, { trigger: "edit" }); + }, + [restartFromUserMessage], + ); const handleApprovalResponse = useCallback( async (decision: "approve" | "decline" | "always_allow_category") => { if (!pendingApproval?.toolCallId) return; @@ -870,6 +969,12 @@ export function ChatMastraInterface({ pendingQuestion={pendingQuestion} isQuestionSubmitting={questionResponsePending} onQuestionRespond={handleQuestionResponse} + editingUserMessageId={editingUserMessageId} + isEditSubmitting={isAwaitingAssistant} + onStartEditUserMessage={setEditingUserMessageId} + onCancelEditUserMessage={() => setEditingUserMessageId(null)} + onSubmitEditedUserMessage={handleSubmitEditedUserMessage} + onRestartUserMessage={handleResendUserMessage} /> {}, + editingUserMessageId: null, + isEditSubmitting: false, + onStartEditUserMessage: () => {}, + onCancelEditUserMessage: () => {}, + onSubmitEditedUserMessage: async () => {}, + onRestartUserMessage: async () => {}, ...overrides, }; } diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.tsx index 25a6ae362ff..596cb06aa8a 100644 --- a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.tsx +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.tsx @@ -56,6 +56,12 @@ export function ChatMastraMessageList({ pendingQuestion, isQuestionSubmitting, onQuestionRespond, + editingUserMessageId, + isEditSubmitting, + onStartEditUserMessage, + onCancelEditUserMessage, + onSubmitEditedUserMessage, + onRestartUserMessage, }: ChatMastraMessageListProps) { const messageListRef = useRef(null); const chatSearch = useChatMessageSearch({ @@ -177,14 +183,22 @@ export function ChatMastraMessageList({ icon={} /> ) : ( - renderedMessages.map((message) => { + renderedMessages.map((message, messageIndex) => { if (message.role === "user") { return ( ); } diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.types.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.types.ts index 47a5feb58dc..c92a9530d56 100644 --- a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.types.ts +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/ChatMastraMessageList.types.ts @@ -44,6 +44,22 @@ export interface InterruptedMessagePreview { content: MastraMessage["content"]; } +export interface UserMessageActionPayload { + content: string; + files?: Array<{ + data: string; + mediaType: string; + filename?: string; + uploaded: false; + }>; +} + +export interface UserMessageRestartRequest { + messageId: string; + prefixMessages: MastraMessage[]; + payload: UserMessageActionPayload; +} + export interface ChatMastraMessageListProps { messages: MastraMessage[]; isFocused: boolean; @@ -73,4 +89,12 @@ export interface ChatMastraMessageListProps { pendingQuestion: MastraPendingQuestion; isQuestionSubmitting: boolean; onQuestionRespond: (questionId: string, answer: string) => Promise; + editingUserMessageId: string | null; + isEditSubmitting: boolean; + onStartEditUserMessage: (messageId: string) => void; + onCancelEditUserMessage: () => void; + onSubmitEditedUserMessage: ( + request: UserMessageRestartRequest, + ) => Promise; + onRestartUserMessage: (request: UserMessageRestartRequest) => Promise; } diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/UserMessage.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/UserMessage.tsx index aaac8543627..60cf70bcb0c 100644 --- a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/UserMessage.tsx +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/UserMessage.tsx @@ -1,33 +1,50 @@ -import type { UseMastraChatDisplayReturn } from "@superset/chat-mastra/client"; -import { Tooltip, TooltipContent, TooltipTrigger } from "@superset/ui/tooltip"; -import { CheckIcon, CopyIcon } from "lucide-react"; import { useCallback, useState } from "react"; import { useTabsStore } from "renderer/stores/tabs/store"; -import { normalizeWorkspaceFilePath } from "../../../../../../ChatPane/ChatInterface/utils/file-paths"; -import { AttachmentChip } from "../AttachmentChip"; -import { parseUserMentions } from "./utils/parseUserMentions"; - -type MastraMessage = NonNullable< - UseMastraChatDisplayReturn["messages"] ->[number]; -type MastraMessagePart = MastraMessage["content"][number]; +import type { + UserMessageActionPayload, + UserMessageRestartRequest, +} from "../../ChatMastraMessageList.types"; +import { UserMessageActions } from "./components/UserMessageActions"; +import { UserMessageAttachments } from "./components/UserMessageAttachments"; +import { UserMessageEditor } from "./components/UserMessageEditor"; +import { UserMessageText } from "./components/UserMessageText"; +import type { MastraMessage } from "./types"; +import { getUserMessageDraft } from "./utils/getUserMessageDraft"; interface UserMessageProps { message: MastraMessage; + prefixMessages: MastraMessage[]; workspaceId: string; workspaceCwd?: string; + isEditing: boolean; + isSubmitting: boolean; + onStartEdit: (messageId: string) => void; + onCancelEdit: () => void; + onSubmitEdit: (request: UserMessageRestartRequest) => Promise; + onRestart: (request: UserMessageRestartRequest) => Promise; + actionDisabled?: boolean; } export function UserMessage({ message, + prefixMessages, workspaceId, workspaceCwd, + isEditing, + isSubmitting, + onStartEdit, + onCancelEdit, + onSubmitEdit, + onRestart, + actionDisabled = false, }: UserMessageProps) { const addFileViewerPane = useTabsStore((store) => store.addFileViewerPane); - const fullText = message.content - .flatMap((part) => (part.type === "text" ? [part.text] : [])) - .join("\n"); + const draft = getUserMessageDraft(message); + const fullText = draft.text; const [copied, setCopied] = useState(false); + const isPersistedMessage = + !message.id.startsWith("optimistic-") && + !message.id.startsWith("immediate-user-message-"); const openAttachment = useCallback( (url: string, filename?: string) => { @@ -57,135 +74,84 @@ export function UserMessage({ }, ); }, [fullText]); + const handleResend = useCallback(() => { + const resendPayload: UserMessageActionPayload = { + content: draft.text, + ...(draft.files.length > 0 + ? { + files: draft.files.map((file) => ({ + data: file.url, + mediaType: file.mediaType, + filename: file.filename, + uploaded: false as const, + })), + } + : {}), + }; + if (!resendPayload.content && !resendPayload.files?.length) { + return; + } + + void onRestart({ + messageId: message.id, + prefixMessages, + payload: resendPayload, + }).catch((error) => { + console.debug("[UserMessage] resend failed", error); + }); + }, [draft.files, draft.text, message.id, onRestart, prefixMessages]); + const showActions = + !isEditing && + Boolean(fullText || draft.files.length > 0) && + isPersistedMessage; return (
- {fullText ? ( - - - - - {!copied ? Copy : null} - + {isEditing ? ( + + onSubmitEdit({ + messageId: message.id, + prefixMessages, + payload, + }) + } + /> ) : null} {message.content.some( (part) => part.type === "image" || (part as { type?: string }).type === "file", - ) && ( -
- {message.content.map((part: MastraMessagePart, partIndex: number) => { - const rawPart = part as { - data?: string; - filename?: string; - mediaType?: string; - mimeType?: string; - type?: string; - }; - if (part.type !== "image" && rawPart.type !== "file") { - return null; - } - - const data = rawPart.data ?? ""; - const mediaType = - rawPart.mediaType ?? - rawPart.mimeType ?? - "application/octet-stream"; - if (!data) { - return null; - } - - if ( - part.type === "image" && - "mimeType" in part && - !rawPart.mediaType - ) { - return ( -
- Attached -
- ); - } - - return ( - openAttachment(data, rawPart.filename)} - /> - ); - })} -
- )} - {message.content.map((part: MastraMessagePart, partIndex: number) => { - if (part.type === "text") { - const mentionSegments = parseUserMentions(part.text); - return ( -
- {mentionSegments.map((segment, segmentIndex) => { - if (segment.type === "text") { - return ( - - {segment.value} - - ); - } - - const normalizedPath = normalizeWorkspaceFilePath({ - filePath: segment.relativePath, - workspaceRoot: workspaceCwd, - }); - const canOpen = Boolean(normalizedPath); - - return ( - - ); - })} -
- ); - } - return null; - })} + ) && + !isEditing && ( + + )} + {!isEditing ? ( + + ) : null} + {showActions ? ( + onStartEdit(message.id)} + onResend={handleResend} + /> + ) : null}
); } diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageActions/UserMessageActions.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageActions/UserMessageActions.tsx new file mode 100644 index 00000000000..cac288b788b --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageActions/UserMessageActions.tsx @@ -0,0 +1,67 @@ +import { + MessageAction, + MessageActions, +} from "@superset/ui/ai-elements/message"; +import { + CheckIcon, + CopyIcon, + PencilLineIcon, + RotateCcwIcon, +} from "lucide-react"; + +interface UserMessageActionsProps { + actionDisabled: boolean; + copied: boolean; + fullText: string; + onCopy: () => void; + onEdit: () => void; + onResend: () => void; +} + +export function UserMessageActions({ + actionDisabled, + copied, + fullText, + onCopy, + onEdit, + onResend, +}: UserMessageActionsProps) { + return ( +
+ + + + + + + + {fullText ? ( + + {copied ? ( + + ) : ( + + )} + + ) : null} + +
+ ); +} diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageActions/index.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageActions/index.ts new file mode 100644 index 00000000000..8f9150e7bef --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageActions/index.ts @@ -0,0 +1 @@ +export { UserMessageActions } from "./UserMessageActions"; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageAttachments/UserMessageAttachments.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageAttachments/UserMessageAttachments.tsx new file mode 100644 index 00000000000..4b46a85999c --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageAttachments/UserMessageAttachments.tsx @@ -0,0 +1,58 @@ +import { AttachmentChip } from "../../../AttachmentChip"; +import type { MastraMessage, MastraMessagePart } from "../../types"; + +interface UserMessageAttachmentsProps { + message: MastraMessage; + onOpenAttachment: (url: string, filename?: string) => void; +} + +export function UserMessageAttachments({ + message, + onOpenAttachment, +}: UserMessageAttachmentsProps) { + return ( +
+ {message.content.map((part: MastraMessagePart, partIndex: number) => { + const rawPart = part as { + data?: string; + filename?: string; + mediaType?: string; + mimeType?: string; + type?: string; + }; + if (part.type !== "image" && rawPart.type !== "file") { + return null; + } + + const data = rawPart.data ?? ""; + const mediaType = + rawPart.mediaType ?? rawPart.mimeType ?? "application/octet-stream"; + if (!data) { + return null; + } + + if (part.type === "image" && "mimeType" in part && !rawPart.mediaType) { + return ( +
+ Attached +
+ ); + } + + return ( + onOpenAttachment(data, rawPart.filename)} + /> + ); + })} +
+ ); +} diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageAttachments/index.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageAttachments/index.ts new file mode 100644 index 00000000000..d9c683dd867 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageAttachments/index.ts @@ -0,0 +1 @@ +export { UserMessageAttachments } from "./UserMessageAttachments"; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageEditor/UserMessageEditor.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageEditor/UserMessageEditor.tsx new file mode 100644 index 00000000000..aec5e381d86 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageEditor/UserMessageEditor.tsx @@ -0,0 +1,117 @@ +import { Button } from "@superset/ui/button"; +import { Loader2Icon } from "lucide-react"; +import { useEffect, useRef, useState } from "react"; +import type { UserMessageActionPayload } from "../../../../ChatMastraMessageList.types"; +import { AttachmentChip } from "../../../AttachmentChip"; +import type { UserMessageDraft } from "../../utils/getUserMessageDraft/getUserMessageDraft"; + +interface UserMessageEditorProps { + initialDraft: UserMessageDraft; + isSubmitting: boolean; + onCancel: () => void; + onSubmit: (payload: UserMessageActionPayload) => Promise; +} + +export function UserMessageEditor({ + initialDraft, + isSubmitting, + onCancel, + onSubmit, +}: UserMessageEditorProps) { + const [text, setText] = useState(initialDraft.text); + const inputRef = useRef(null); + const files = initialDraft.files; + + useEffect(() => { + setText(initialDraft.text); + }, [initialDraft.text]); + + useEffect(() => { + const input = inputRef.current; + if (!input) return; + input.focus(); + input.setSelectionRange(input.value.length, input.value.length); + }, []); + + const canSubmit = Boolean(text.trim() || files.length > 0); + const handleSubmit = () => { + if (!canSubmit || isSubmitting) return; + void onSubmit({ + content: text, + ...(files.length > 0 + ? { + files: files.map((file) => ({ + data: file.url, + mediaType: file.mediaType, + filename: file.filename, + uploaded: false as const, + })), + } + : {}), + }); + }; + + return ( +
+ {files.length > 0 ? ( +
+ {files.map((file, index) => ( + + ))} +
+ ) : null} + setText(event.currentTarget.value)} + onKeyDown={(event) => { + if (event.key === "Escape") { + event.preventDefault(); + onCancel(); + return; + } + if (event.key !== "Enter") return; + event.preventDefault(); + handleSubmit(); + }} + placeholder="Edit message..." + className="h-9 w-full rounded-xl border border-transparent bg-muted/45 px-3 text-sm text-foreground outline-none transition-colors placeholder:text-muted-foreground focus:border-border focus:bg-background/70" + /> +
+ + +
+
+ ); +} diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageEditor/index.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageEditor/index.ts new file mode 100644 index 00000000000..401a6b984ec --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageEditor/index.ts @@ -0,0 +1 @@ +export { UserMessageEditor } from "./UserMessageEditor"; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageText/UserMessageText.tsx b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageText/UserMessageText.tsx new file mode 100644 index 00000000000..eac5bc19b6a --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageText/UserMessageText.tsx @@ -0,0 +1,65 @@ +import { normalizeWorkspaceFilePath } from "../../../../../../../../ChatPane/ChatInterface/utils/file-paths"; +import type { MastraMessage, MastraMessagePart } from "../../types"; +import { parseUserMentions } from "../../utils/parseUserMentions"; + +interface UserMessageTextProps { + message: MastraMessage; + workspaceCwd?: string; + onOpenMentionedFile: (filePath: string) => void; +} + +export function UserMessageText({ + message, + workspaceCwd, + onOpenMentionedFile, +}: UserMessageTextProps) { + return message.content.map((part: MastraMessagePart, partIndex: number) => { + if (part.type !== "text") { + return null; + } + + const mentionSegments = parseUserMentions(part.text); + return ( +
+ {mentionSegments.map((segment, segmentIndex) => { + if (segment.type === "text") { + return ( + + {segment.value} + + ); + } + + const normalizedPath = normalizeWorkspaceFilePath({ + filePath: segment.relativePath, + workspaceRoot: workspaceCwd, + }); + const canOpen = Boolean(normalizedPath); + + return ( + + ); + })} +
+ ); + }); +} diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageText/index.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageText/index.ts new file mode 100644 index 00000000000..5464ea3fbb4 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/components/UserMessageText/index.ts @@ -0,0 +1 @@ +export { UserMessageText } from "./UserMessageText"; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/types.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/types.ts new file mode 100644 index 00000000000..37e3ef106b2 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/types.ts @@ -0,0 +1,7 @@ +import type { UseMastraChatDisplayReturn } from "@superset/chat-mastra/client"; + +export type MastraMessage = NonNullable< + UseMastraChatDisplayReturn["messages"] +>[number]; + +export type MastraMessagePart = MastraMessage["content"][number]; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/getUserMessageDraft.test.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/getUserMessageDraft.test.ts new file mode 100644 index 00000000000..bbe286e4082 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/getUserMessageDraft.test.ts @@ -0,0 +1,62 @@ +import { describe, expect, it } from "bun:test"; +import { getUserMessageDraft } from "./getUserMessageDraft"; + +function createMessage( + content: Array>, +): Parameters[0] { + return { + id: "message-1", + role: "user", + content, + createdAt: new Date("2026-03-06T00:00:00.000Z"), + } as Parameters[0]; +} + +describe("getUserMessageDraft", () => { + it("collects text across multiple text parts", () => { + const message = createMessage([ + { type: "text", text: "First line" }, + { type: "text", text: "Second line" }, + ]); + + expect(getUserMessageDraft(message)).toEqual({ + text: "First line\nSecond line", + files: [], + }); + }); + + it("converts files and inline images into prompt-input files", () => { + const message = createMessage([ + { type: "text", text: "Review this" }, + { + type: "file", + data: "https://example.com/spec.pdf", + mediaType: "application/pdf", + filename: "spec.pdf", + }, + { + type: "image", + data: "ZmFrZQ==", + mimeType: "image/png", + }, + ]); + + expect(getUserMessageDraft(message)).toEqual({ + text: "Review this", + files: [ + { + type: "file", + url: "https://example.com/spec.pdf", + mediaType: "application/pdf", + filename: "spec.pdf", + }, + { + type: "file", + url: "data:image/png;base64,ZmFrZQ==", + mediaType: "image/png", + filename: undefined, + }, + ], + }); + }); +}); diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/getUserMessageDraft.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/getUserMessageDraft.ts new file mode 100644 index 00000000000..5fd60f85d07 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/getUserMessageDraft.ts @@ -0,0 +1,81 @@ +import type { UseMastraChatDisplayReturn } from "@superset/chat-mastra/client"; +import type { FileUIPart } from "ai"; + +type MastraMessage = NonNullable< + UseMastraChatDisplayReturn["messages"] +>[number]; +type MastraMessagePart = MastraMessage["content"][number]; + +interface AttachmentSource { + url: string; + mediaType: string; + filename?: string; +} + +export interface UserMessageDraft { + text: string; + files: FileUIPart[]; +} + +function getUserMessageText(message: MastraMessage): string { + return message.content + .flatMap((part) => (part.type === "text" ? [part.text] : [])) + .join("\n"); +} + +function toAttachmentSource(part: MastraMessagePart): AttachmentSource | null { + const rawPart = part as { + data?: string; + filename?: string; + image?: string; + mediaType?: string; + mimeType?: string; + type?: string; + }; + + if (part.type !== "image" && rawPart.type !== "file") { + return null; + } + + const mediaType = + rawPart.mediaType ?? rawPart.mimeType ?? "application/octet-stream"; + const data = rawPart.data ?? rawPart.image ?? ""; + if (!data) { + return null; + } + + if (part.type === "image" && "mimeType" in part && !rawPart.mediaType) { + return { + url: `data:${part.mimeType};base64,${part.data}`, + mediaType: part.mimeType, + filename: rawPart.filename, + }; + } + + return { + url: data, + mediaType, + filename: rawPart.filename, + }; +} + +function getUserMessageAttachmentSources( + message: MastraMessage, +): AttachmentSource[] { + return message.content.flatMap((part) => { + const attachment = toAttachmentSource(part); + return attachment ? [attachment] : []; + }); +} + +export function getUserMessageDraft(message: MastraMessage): UserMessageDraft { + return { + text: getUserMessageText(message), + files: getUserMessageAttachmentSources(message).map((attachment) => ({ + type: "file", + url: attachment.url, + mediaType: attachment.mediaType, + filename: attachment.filename, + })), + }; +} diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/index.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/index.ts new file mode 100644 index 00000000000..91d2ecebcc1 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/components/ChatMastraMessageList/components/UserMessage/utils/getUserMessageDraft/index.ts @@ -0,0 +1 @@ +export { getUserMessageDraft } from "./getUserMessageDraft"; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/index.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/index.ts new file mode 100644 index 00000000000..f3203074406 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/index.ts @@ -0,0 +1,5 @@ +export { + getVisibleMessagesWithPendingUserTurn, + type PendingUserTurn, + shouldClearPendingUserTurn, +} from "./transientUserTurn"; diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/transientUserTurn.test.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/transientUserTurn.test.ts new file mode 100644 index 00000000000..5310180b7ab --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/transientUserTurn.test.ts @@ -0,0 +1,106 @@ +import { describe, expect, it } from "bun:test"; +import { + getVisibleMessagesWithPendingUserTurn, + type PendingUserTurn, + shouldClearPendingUserTurn, +} from "./transientUserTurn"; + +type TestMessage = { + id: string; + role: "user" | "assistant"; + content: Array<{ type: "text"; text: string }>; + createdAt: Date; +}; + +function message( + id: string, + role: TestMessage["role"], + text: string, + createdAt = "2026-03-07T01:00:00.000Z", +): TestMessage { + return { + id, + role, + content: [{ type: "text", text }], + createdAt: new Date(createdAt), + }; +} + +describe("getVisibleMessagesWithPendingUserTurn", () => { + it("appends a pending composer send until it is persisted", () => { + const messages = [message("u1", "user", "hello")] as TestMessage[]; + const pendingUserTurn: PendingUserTurn = { + kind: "append", + message: message("optimistic-1", "user", "follow up"), + }; + + expect( + getVisibleMessagesWithPendingUserTurn({ + messages: messages as never, + pendingUserTurn: pendingUserTurn as never, + isAwaitingAssistant: true, + }), + ).toHaveLength(2); + }); + + it("keeps the rendered prefix while a restarted turn is streaming", () => { + const persistedMessages = [ + message("u1", "user", "hey bos"), + message("a1", "assistant", "Hey! What can I help you with today?"), + message("u2", "user", "whats your model?"), + ] as TestMessage[]; + const pendingUserTurn: PendingUserTurn = { + kind: "restart", + prefixMessages: persistedMessages.slice(0, 2) as never, + message: message("optimistic-2", "user", "whats your model?"), + }; + + expect( + getVisibleMessagesWithPendingUserTurn({ + messages: persistedMessages as never, + pendingUserTurn: pendingUserTurn as never, + isAwaitingAssistant: true, + }), + ).toEqual([ + persistedMessages[0], + persistedMessages[1], + pendingUserTurn.message, + ]); + }); +}); + +describe("shouldClearPendingUserTurn", () => { + it("does not clear a restart overlay while the assistant is still pending", () => { + const messages = [message("u1", "user", "hello")] as TestMessage[]; + const pendingUserTurn: PendingUserTurn = { + kind: "restart", + prefixMessages: [], + message: message("optimistic-1", "user", "hello"), + }; + + expect( + shouldClearPendingUserTurn({ + messages: messages as never, + pendingUserTurn: pendingUserTurn as never, + isAwaitingAssistant: true, + }), + ).toBe(false); + }); + + it("clears a restart overlay once the restarted user message is persisted and streaming is done", () => { + const messages = [message("u1", "user", "hello")] as TestMessage[]; + const pendingUserTurn: PendingUserTurn = { + kind: "restart", + prefixMessages: [], + message: message("optimistic-1", "user", "hello"), + }; + + expect( + shouldClearPendingUserTurn({ + messages: messages as never, + pendingUserTurn: pendingUserTurn as never, + isAwaitingAssistant: false, + }), + ).toBe(true); + }); +}); diff --git a/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/transientUserTurn.ts b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/transientUserTurn.ts new file mode 100644 index 00000000000..ed150d428b2 --- /dev/null +++ b/apps/desktop/src/renderer/screens/main/components/WorkspaceView/ContentView/TabsContent/TabView/ChatMastraPane/ChatMastraInterface/utils/transientUserTurn/transientUserTurn.ts @@ -0,0 +1,71 @@ +import { + hasMatchingUserMessage, + type MastraHistoryMessage, +} from "../optimisticUserMessage"; + +export type PendingUserTurn = + | { + kind: "append"; + message: MastraHistoryMessage; + } + | { + kind: "restart"; + message: MastraHistoryMessage; + prefixMessages: MastraHistoryMessage[]; + }; + +export function shouldClearPendingUserTurn({ + messages, + pendingUserTurn, + isAwaitingAssistant, +}: { + messages: MastraHistoryMessage[]; + pendingUserTurn: PendingUserTurn | null; + isAwaitingAssistant: boolean; +}): boolean { + if (!pendingUserTurn) return false; + if ( + !hasMatchingUserMessage({ + messages, + candidate: pendingUserTurn.message, + }) + ) { + return false; + } + + if (pendingUserTurn.kind === "restart" && isAwaitingAssistant) { + return false; + } + + return true; +} + +export function getVisibleMessagesWithPendingUserTurn({ + messages, + pendingUserTurn, + isAwaitingAssistant, +}: { + messages: MastraHistoryMessage[]; + pendingUserTurn: PendingUserTurn | null; + isAwaitingAssistant: boolean; +}): MastraHistoryMessage[] { + if (!pendingUserTurn) return messages; + + const hasPersistedMessage = hasMatchingUserMessage({ + messages, + candidate: pendingUserTurn.message, + }); + + if (pendingUserTurn.kind === "restart") { + if (isAwaitingAssistant || !hasPersistedMessage) { + return [...pendingUserTurn.prefixMessages, pendingUserTurn.message]; + } + return messages; + } + + if (hasPersistedMessage) { + return messages; + } + + return [...messages, pendingUserTurn.message]; +} diff --git a/packages/chat-mastra/src/server/trpc/service.ts b/packages/chat-mastra/src/server/trpc/service.ts index 1cb82d6b1a8..713b5729aeb 100644 --- a/packages/chat-mastra/src/server/trpc/service.ts +++ b/packages/chat-mastra/src/server/trpc/service.ts @@ -12,6 +12,7 @@ import { onUserPromptSubmit, type RuntimeSession, reloadHookConfig, + restartRuntimeFromUserMessage, runSessionStartHook, subscribeToSessionEvents, } from "./utils/runtime"; @@ -24,6 +25,7 @@ import { mcpServerAuthInput, planRespondInput, questionRespondInput, + restartFromMessageInput, searchFilesInput, sendMessageInput, sessionIdInput, @@ -246,6 +248,31 @@ export class ChatMastraService { return runtime.harness.sendMessage(input.payload); }), + restartFromMessage: t.procedure + .input(restartFromMessageInput) + .mutation(async ({ input }) => { + const runtime = await this.getOrCreateRuntime( + input.sessionId, + input.cwd, + ); + runtime.lastErrorMessage = null; + const userMessage = + input.payload.content.trim() || "[non-text message]"; + await onUserPromptSubmit(runtime, userMessage); + const submittedUserMessage = input.payload.content.trim(); + await restartRuntimeFromUserMessage(runtime, { + messageId: input.messageId, + payload: input.payload, + metadata: input.metadata, + }); + void generateAndSetTitle(runtime, this.apiClient, { + submittedUserMessage: + submittedUserMessage.length > 0 + ? submittedUserMessage + : undefined, + }); + }), + stop: t.procedure.input(sessionIdInput).mutation(async ({ input }) => { const runtime = await this.getOrCreateRuntime( input.sessionId, diff --git a/packages/chat-mastra/src/server/trpc/utils/runtime/index.ts b/packages/chat-mastra/src/server/trpc/utils/runtime/index.ts index 5a787fcab6d..26ece80e0a1 100644 --- a/packages/chat-mastra/src/server/trpc/utils/runtime/index.ts +++ b/packages/chat-mastra/src/server/trpc/utils/runtime/index.ts @@ -8,6 +8,7 @@ export { type RuntimeMcpServerStatus, type RuntimeSession, reloadHookConfig, + restartRuntimeFromUserMessage, runSessionStartHook, subscribeToSessionEvents, } from "./runtime"; diff --git a/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.test.ts b/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.test.ts index 8e91bf1c10e..1cbf0998d2e 100644 --- a/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.test.ts +++ b/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.test.ts @@ -2,6 +2,7 @@ import { describe, expect, it } from "bun:test"; import { generateAndSetTitle, type RuntimeSession, + restartRuntimeFromUserMessage, subscribeToSessionEvents, } from "./runtime"; @@ -238,3 +239,98 @@ describe("runtime title generation", () => { expect(updateTitleInputs).toEqual([]); }); }); + +describe("runtime message restart", () => { + it("clones the thread up to the target user message and resends from there", async () => { + const cloneThreadInputs: Array> = []; + const sendMessageInputs: Array> = []; + const switchThreadInputs: Array> = []; + const switchModelInputs: Array> = []; + + const memoryStore = { + getThreadById: async () => ({ + id: "thread-1", + resourceId: "resource-1", + title: "Existing Thread", + }), + listMessages: async () => ({ + messages: [ + { id: "user-1", role: "user" }, + { id: "assistant-1", role: "assistant" }, + { id: "user-2", role: "user" }, + { id: "assistant-2", role: "assistant" }, + ], + }), + cloneThread: async (input: Record) => { + cloneThreadInputs.push(input); + return { + thread: { + id: "thread-2", + resourceId: "resource-1", + title: "Existing Thread", + }, + }; + }, + }; + + const runtime: RuntimeSession = { + sessionId: "11111111-1111-1111-1111-111111111111", + harness: { + getCurrentThreadId: () => "thread-1", + abort: () => {}, + switchThread: async (input: Record) => { + switchThreadInputs.push(input); + }, + switchModel: async (input: Record) => { + switchModelInputs.push(input); + }, + sendMessage: async (input: Record) => { + sendMessageInputs.push(input); + }, + config: { + storage: { + getStore: async () => memoryStore, + }, + }, + } as unknown as RuntimeSession["harness"], + mcpManager: null as RuntimeSession["mcpManager"], + hookManager: null as RuntimeSession["hookManager"], + mcpManualStatuses: new Map(), + lastErrorMessage: "stale error", + pendingSandboxQuestion: null, + cwd: "/tmp", + }; + + await restartRuntimeFromUserMessage(runtime, { + messageId: "user-2", + payload: { + content: "Edited prompt", + }, + metadata: { + model: "anthropic/claude-sonnet-4", + }, + }); + + expect(cloneThreadInputs).toEqual([ + { + sourceThreadId: "thread-1", + resourceId: "resource-1", + title: "Existing Thread", + options: { + messageFilter: { + messageIds: ["user-1", "assistant-1"], + }, + }, + }, + ]); + expect(switchThreadInputs).toEqual([{ threadId: "thread-2" }]); + expect(switchModelInputs).toEqual([ + { + modelId: "anthropic/claude-sonnet-4", + scope: "thread", + }, + ]); + expect(sendMessageInputs).toEqual([{ content: "Edited prompt" }]); + expect(runtime.lastErrorMessage).toBeNull(); + }); +}); diff --git a/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.ts b/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.ts index ae76adfc61c..e0c55eaa3d5 100644 --- a/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.ts +++ b/packages/chat-mastra/src/server/trpc/utils/runtime/runtime.ts @@ -1,6 +1,7 @@ import type { AppRouter } from "@superset/trpc"; import type { createTRPCClient } from "@trpc/client"; import type { createMastraCode } from "mastracode"; +import { generateTitleFromMessage } from "./title-generation"; export type RuntimeHarness = Awaited< ReturnType @@ -44,6 +45,78 @@ interface MessageLike { content: Array<{ type: string; text?: string }>; } +interface RuntimeRestartPayload { + messageId: string; + payload: { + content: string; + files?: Array<{ + data: string; + mediaType: string; + filename?: string; + }>; + }; + metadata?: { + model?: string; + }; +} + +interface RuntimeStoredMessage { + id: string; + role: string; +} + +interface RuntimeStoredThread { + id: string; + resourceId: string; + title?: string; +} + +interface RuntimeMemoryStore { + getThreadById(args: { + threadId: string; + }): Promise; + listMessages(args: { + threadId: string; + perPage: false; + orderBy: { field: "createdAt"; direction: "ASC" }; + }): Promise<{ messages: RuntimeStoredMessage[] }>; + cloneThread(args: { + sourceThreadId: string; + resourceId?: string; + title?: string; + options?: { + messageFilter?: { + messageIds?: string[]; + }; + }; + }): Promise<{ thread: RuntimeStoredThread }>; +} + +interface HarnessWithConfig { + config?: { + storage?: { + getStore: (domain: "memory") => Promise; + }; + }; +} + +async function getRuntimeMemoryStore( + runtime: RuntimeSession, +): Promise { + const harness = runtime.harness as unknown as HarnessWithConfig; + const storage = harness.config?.storage; + if (!storage) { + throw new Error("Mastra storage is not configured for this session"); + } + + const memoryStore = await storage.getStore("memory"); + if (!memoryStore) { + throw new Error("Mastra memory storage is unavailable for this session"); + } + + return memoryStore; +} + /** * Gate: validates user prompt against hooks before sending. * Throws if the hook blocks the message. @@ -231,6 +304,66 @@ function extractProviderMessage(error: unknown): string | null { return null; } +export async function restartRuntimeFromUserMessage( + runtime: RuntimeSession, + input: RuntimeRestartPayload, +): Promise { + const threadId = runtime.harness.getCurrentThreadId(); + if (!threadId) { + throw new Error("No active Mastra thread is available for editing"); + } + + const memoryStore = await getRuntimeMemoryStore(runtime); + const sourceThread = await memoryStore.getThreadById({ threadId }); + if (!sourceThread) { + throw new Error(`Mastra thread not found: ${threadId}`); + } + + const sourceMessages = await memoryStore.listMessages({ + threadId, + perPage: false, + orderBy: { field: "createdAt", direction: "ASC" }, + }); + const targetIndex = sourceMessages.messages.findIndex( + (message) => message.id === input.messageId, + ); + if (targetIndex === -1) { + throw new Error("The selected message is no longer available to edit"); + } + + const targetMessage = sourceMessages.messages[targetIndex]; + if (targetMessage?.role !== "user") { + throw new Error("Only user messages can be edited or resent"); + } + + const clonedThread = await memoryStore.cloneThread({ + sourceThreadId: threadId, + resourceId: sourceThread.resourceId, + title: sourceThread.title, + options: { + messageFilter: { + messageIds: sourceMessages.messages + .slice(0, targetIndex) + .map((message) => message.id), + }, + }, + }); + + runtime.harness.abort(); + await runtime.harness.switchThread({ threadId: clonedThread.thread.id }); + + const selectedModel = input.metadata?.model?.trim(); + if (selectedModel) { + await runtime.harness.switchModel({ + modelId: selectedModel, + scope: "thread", + }); + } + + runtime.lastErrorMessage = null; + await runtime.harness.sendMessage(input.payload); +} + function extractTextContent(parts: MessageLike["content"]): string { return parts .filter( @@ -309,5 +442,3 @@ export async function generateAndSetTitle( console.warn("[chat-mastra] Title generation failed:", error); } } - -import { generateTitleFromMessage } from "@superset/chat/host"; diff --git a/packages/chat-mastra/src/server/trpc/utils/runtime/title-generation.ts b/packages/chat-mastra/src/server/trpc/utils/runtime/title-generation.ts new file mode 100644 index 00000000000..94f3ab23dfe --- /dev/null +++ b/packages/chat-mastra/src/server/trpc/utils/runtime/title-generation.ts @@ -0,0 +1,71 @@ +type TitleModel = unknown; +type TitleAgent = { + generateTitleFromUserMessage: (args: { + message: string; + model?: string; + tracingContext?: Record; + }) => Promise; +}; +type TitleAgentCtor = new (options: { + id: string; + name: string; + instructions: string; + model: TitleModel; +}) => TitleAgent; + +type GenerateTitleFromMessageParams = + | { + message: string; + agent: TitleAgent; + modelId: string; + tracingContext?: Record; + } + | { + message: string; + agentModel: TitleModel; + agentId?: string; + agentName?: string; + instructions?: string; + tracingContext?: Record; + }; + +export async function generateTitleFromMessage( + params: GenerateTitleFromMessageParams, +): Promise { + const { message, tracingContext = {} } = params; + const cleanedMessage = message.trim(); + if (!cleanedMessage) { + return null; + } + + if ("agent" in params) { + const title = await params.agent.generateTitleFromUserMessage({ + message: cleanedMessage, + model: params.modelId, + tracingContext, + }); + return title?.trim() || null; + } + + const agentModuleId = "@mastra/core/agent"; + const { Agent } = (await import(agentModuleId)) as { + Agent?: TitleAgentCtor; + }; + if (!Agent) { + throw new Error("Mastra Agent constructor is unavailable"); + } + + const titleAgent = new Agent({ + id: params.agentId ?? "title-generator", + name: params.agentName ?? "Title Generator", + instructions: params.instructions ?? "You generate concise titles.", + model: params.agentModel, + }); + + const title = await titleAgent.generateTitleFromUserMessage({ + message: cleanedMessage, + tracingContext, + }); + + return title?.trim() || null; +} diff --git a/packages/chat-mastra/src/server/trpc/zod.ts b/packages/chat-mastra/src/server/trpc/zod.ts index 335af76172f..ad20b8b2ec2 100644 --- a/packages/chat-mastra/src/server/trpc/zod.ts +++ b/packages/chat-mastra/src/server/trpc/zod.ts @@ -81,6 +81,18 @@ export const sendMessageInput = z.object({ .optional(), }); +export const restartFromMessageInput = z.object({ + sessionId: z.uuid(), + cwd: z.string().optional(), + messageId: z.string().min(1), + payload: sendMessagePayloadSchema, + metadata: z + .object({ + model: z.string().optional(), + }) + .optional(), +}); + export const approvalRespondInput = z.object({ sessionId: z.uuid(), cwd: z.string().optional(), @@ -110,6 +122,7 @@ export type PlanPayloadInput = z.infer; export type DisplayStateInput = z.infer; export type ListMessagesInput = z.infer; export type SendMessageInput = z.infer; +export type RestartFromMessageInput = z.infer; export type ApprovalRespondInput = z.infer; export type QuestionRespondInput = z.infer; export type PlanRespondInput = z.infer; diff --git a/packages/chat/src/host/title-generation/title-generation.ts b/packages/chat/src/host/title-generation/title-generation.ts index 91c390fb3cf..94f3ab23dfe 100644 --- a/packages/chat/src/host/title-generation/title-generation.ts +++ b/packages/chat/src/host/title-generation/title-generation.ts @@ -1,7 +1,17 @@ -import { Agent } from "@mastra/core/agent"; - -type TitleAgent = Pick; -type TitleModel = ConstructorParameters[0]["model"]; +type TitleModel = unknown; +type TitleAgent = { + generateTitleFromUserMessage: (args: { + message: string; + model?: string; + tracingContext?: Record; + }) => Promise; +}; +type TitleAgentCtor = new (options: { + id: string; + name: string; + instructions: string; + model: TitleModel; +}) => TitleAgent; type GenerateTitleFromMessageParams = | { @@ -37,6 +47,14 @@ export async function generateTitleFromMessage( return title?.trim() || null; } + const agentModuleId = "@mastra/core/agent"; + const { Agent } = (await import(agentModuleId)) as { + Agent?: TitleAgentCtor; + }; + if (!Agent) { + throw new Error("Mastra Agent constructor is unavailable"); + } + const titleAgent = new Agent({ id: params.agentId ?? "title-generator", name: params.agentName ?? "Title Generator",