diff --git a/packages/app/src/i18n/en.ts b/packages/app/src/i18n/en.ts index 97a572f1cf2..c87e7cb9dbb 100644 --- a/packages/app/src/i18n/en.ts +++ b/packages/app/src/i18n/en.ts @@ -530,6 +530,11 @@ export const dict = { "session.todo.title": "Todos", "session.todo.collapse": "Collapse", "session.todo.expand": "Expand", + "session.revertDock.summary.one": "{{count}} rolled back message", + "session.revertDock.summary.other": "{{count}} rolled back messages", + "session.revertDock.collapse": "Collapse rolled back messages", + "session.revertDock.expand": "Expand rolled back messages", + "session.revertDock.restore": "Restore message", "session.new.title": "Build anything", "session.new.worktree.main": "Main branch", diff --git a/packages/app/src/pages/session.tsx b/packages/app/src/pages/session.tsx index a5c7bf90b32..1b62b94294c 100644 --- a/packages/app/src/pages/session.tsx +++ b/packages/app/src/pages/session.tsx @@ -43,6 +43,7 @@ import { SessionSidePanel } from "@/pages/session/session-side-panel" import { TerminalPanel } from "@/pages/session/terminal-panel" import { useSessionCommands } from "@/pages/session/use-session-commands" import { useSessionHashScroll } from "@/pages/session/use-session-hash-scroll" +import { extractPromptFromParts } from "@/utils/prompt" import { same } from "@/utils/same" import { formatServerError } from "@/utils/server-errors" @@ -286,6 +287,7 @@ export default function Page() { const [ui, setUi] = createStore({ git: false, pendingMessage: undefined as string | undefined, + restoring: undefined as string | undefined, reviewSnap: false, scrollGesture: 0, scroll: { @@ -1179,6 +1181,110 @@ export default function Page() { scroller: () => scroller, }) + const draft = (id: string) => + extractPromptFromParts(sync.data.part[id] ?? [], { + directory: sdk.directory, + attachmentName: language.t("common.attachment"), + }) + + const line = (id: string) => { + const text = draft(id) + .map((part) => (part.type === "image" ? `[image:${part.filename}]` : part.content)) + .join("") + .replace(/\s+/g, " ") + .trim() + if (text) return text + return `[${language.t("common.attachment")}]` + } + + const fail = (err: unknown) => { + showToast({ + variant: "error", + title: language.t("common.requestFailed"), + description: formatServerError(err, language.t), + }) + } + + const busy = (sessionID: string) => { + if (sync.data.session_status[sessionID]?.type !== "idle") return true + return (sync.data.message[sessionID] ?? []).some( + (item) => item.role === "assistant" && typeof item.time.completed !== "number", + ) + } + + const halt = (sessionID: string) => + busy(sessionID) ? sdk.client.session.abort({ sessionID }).catch(() => {}) : Promise.resolve() + + const fork = (input: { sessionID: string; messageID: string }) => { + const value = draft(input.messageID) + return sdk.client.session + .fork(input) + .then((result) => { + const next = result.data + if (!next) { + showToast({ + variant: "error", + title: language.t("common.requestFailed"), + }) + return + } + navigate(`/${base64Encode(sdk.directory)}/session/${next.id}`) + requestAnimationFrame(() => { + prompt.set(value) + }) + }) + .catch(fail) + } + + const revert = (input: { sessionID: string; messageID: string }) => { + const value = draft(input.messageID) + return halt(input.sessionID) + .then(() => sdk.client.session.revert(input)) + .then(() => { + prompt.set(value) + }) + .catch(fail) + } + + const restore = (id: string) => { + const sessionID = params.id + if (!sessionID || ui.restoring) return + + const next = userMessages().find((item) => item.id > id) + setUi("restoring", id) + + const task = !next + ? halt(sessionID) + .then(() => sdk.client.session.unrevert({ sessionID })) + .then(() => { + prompt.reset() + }) + : halt(sessionID) + .then(() => + sdk.client.session.revert({ + sessionID, + messageID: next.id, + }), + ) + .then(() => { + prompt.set(draft(next.id)) + }) + + return task.catch(fail).finally(() => { + setUi("restoring", (value) => (value === id ? undefined : value)) + }) + } + + const rolled = createMemo(() => { + const id = revertMessageID() + if (!id) return [] + return userMessages() + .filter((item) => item.id >= id) + .map((item) => ({ id: item.id, text: line(item.id) })) + }) + + const actions = { fork, revert } + createResizeObserver( () => promptDock, ({ height }) => { @@ -1268,6 +1374,7 @@ export default function Page() { loadingClass: "px-4 py-4 text-text-weak", emptyClass: "h-full pb-64 -mt-4 flex flex-col items-center justify-center text-center gap-6", })} + actions={actions} scroll={ui.scroll} onResumeScroll={resumeScroll} setScrollRef={setScrollRef} @@ -1333,6 +1440,15 @@ export default function Page() { resumeScroll() }} onResponseSubmit={resumeScroll} + revert={ + rolled().length > 0 + ? { + items: rolled(), + restoring: ui.restoring, + onRestore: restore, + } + : undefined + } setPromptDockRef={(el) => { promptDock = el }} diff --git a/packages/app/src/pages/session/composer/session-composer-region.tsx b/packages/app/src/pages/session/composer/session-composer-region.tsx index 93ea3d465c5..08746b51a56 100644 --- a/packages/app/src/pages/session/composer/session-composer-region.tsx +++ b/packages/app/src/pages/session/composer/session-composer-region.tsx @@ -8,6 +8,7 @@ import { usePrompt } from "@/context/prompt" import { getSessionHandoff, setSessionHandoff } from "@/pages/session/handoff" import { SessionPermissionDock } from "@/pages/session/composer/session-permission-dock" import { SessionQuestionDock } from "@/pages/session/composer/session-question-dock" +import { SessionRevertDock } from "@/pages/session/composer/session-revert-dock" import type { SessionComposerState } from "@/pages/session/composer/session-composer-state" import { SessionTodoDock } from "@/pages/session/composer/session-todo-dock" @@ -20,6 +21,11 @@ export function SessionComposerRegion(props: { onNewSessionWorktreeReset: () => void onSubmit: () => void onResponseSubmit: () => void + revert?: { + items: { id: string; text: string }[] + restoring?: string + onRestore: (id: string) => void + } setPromptDockRef: (el: HTMLDivElement) => void visualDuration?: number bounce?: number @@ -116,6 +122,8 @@ export function SessionComposerRegion(props: { const value = createMemo(() => Math.max(0, Math.min(1, progress()))) const [height, setHeight] = createSignal(320) const dock = createMemo(() => (gate.ready && props.state.dock()) || value() > 0.001) + const rolled = createMemo(() => (props.revert?.items.length ? props.revert : undefined)) + const lift = createMemo(() => (rolled() ? 18 : 36 * value())) const full = createMemo(() => Math.max(78, height())) const [contentRef, setContentRef] = createSignal() @@ -170,9 +178,22 @@ export function SessionComposerRegion(props: { - {handoffPrompt() || language.t("prompt.loading")} - + <> + + {(revert) => ( +
+ +
+ )} +
+
+ {handoffPrompt() || language.t("prompt.loading")} +
+ } > @@ -209,12 +230,23 @@ export function SessionComposerRegion(props: { + + {(revert) => ( +
+ +
+ )} +
void +}) { + const language = useLanguage() + const [store, setStore] = createStore({ + collapsed: false, + }) + + const toggle = () => setStore("collapsed", (value) => !value) + const total = createMemo(() => props.items.length) + const label = createMemo(() => + language.t(total() === 1 ? "session.revertDock.summary.one" : "session.revertDock.summary.other", { + count: total(), + }), + ) + const preview = createMemo(() => props.items[0]?.text ?? "") + + return ( + +
{ + if (event.key !== "Enter" && event.key !== " ") return + event.preventDefault() + toggle() + }} + > + {label()} + + {preview()} + +
+ { + event.preventDefault() + event.stopPropagation() + }} + onClick={(event) => { + event.stopPropagation() + toggle() + }} + aria-label={ + store.collapsed ? language.t("session.revertDock.expand") : language.t("session.revertDock.collapse") + } + /> +
+
+ + +