From ceb44bcc5962410b848f947f1401b25dddee3a8b Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Fri, 10 May 2024 11:29:20 +0200 Subject: [PATCH] feat (ai/ui): add stop() helper to useAssistant (#1524) --- .changeset/light-chairs-clap.md | 5 +++ .../next-openai/app/api/assistant/route.ts | 31 ++++++++++------ examples/next-openai/app/assistant/page.tsx | 20 ++++++++-- examples/next-openai/package.json | 2 +- packages/core/package.json | 6 ++- packages/core/react/use-assistant.ts | 37 +++++++++++++++++-- packages/core/streams/assistant-response.ts | 2 +- pnpm-lock.yaml | 16 ++++---- 8 files changed, 90 insertions(+), 29 deletions(-) create mode 100644 .changeset/light-chairs-clap.md diff --git a/.changeset/light-chairs-clap.md b/.changeset/light-chairs-clap.md new file mode 100644 index 000000000000..5435f89f278a --- /dev/null +++ b/.changeset/light-chairs-clap.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat (ai/ui): add stop() helper to useAssistant (important: AssistantResponse now requires OpenAI SDK 4.42+) diff --git a/examples/next-openai/app/api/assistant/route.ts b/examples/next-openai/app/api/assistant/route.ts index 7ba23e8cb76d..5128f84b5e02 100644 --- a/examples/next-openai/app/api/assistant/route.ts +++ b/examples/next-openai/app/api/assistant/route.ts @@ -27,22 +27,30 @@ export async function POST(req: Request) { const threadId = input.threadId ?? (await openai.beta.threads.create({})).id; // Add a message to the thread - const createdMessage = await openai.beta.threads.messages.create(threadId, { - role: 'user', - content: input.message, - }); + const createdMessage = await openai.beta.threads.messages.create( + threadId, + { + role: 'user', + content: input.message, + }, + { signal: req.signal }, + ); return AssistantResponse( { threadId, messageId: createdMessage.id }, async ({ forwardStream, sendDataMessage }) => { // Run the assistant on the thread - const runStream = openai.beta.threads.runs.createAndStream(threadId, { - assistant_id: - process.env.ASSISTANT_ID ?? - (() => { - throw new Error('ASSISTANT_ID is not set'); - })(), - }); + const runStream = openai.beta.threads.runs.stream( + threadId, + { + assistant_id: + process.env.ASSISTANT_ID ?? + (() => { + throw new Error('ASSISTANT_ID is not set'); + })(), + }, + { signal: req.signal }, + ); // forward run status would stream message deltas let runResult = await forwardStream(runStream); @@ -108,6 +116,7 @@ export async function POST(req: Request) { threadId, runResult.id, { tool_outputs }, + { signal: req.signal }, ), ); } diff --git a/examples/next-openai/app/assistant/page.tsx b/examples/next-openai/app/assistant/page.tsx index 1bfb87989e79..a8ae32187650 100644 --- a/examples/next-openai/app/assistant/page.tsx +++ b/examples/next-openai/app/assistant/page.tsx @@ -13,8 +13,15 @@ const roleToColorMap: Record = { }; export default function Chat() { - const { status, messages, input, submitMessage, handleInputChange, error } = - useAssistant({ api: '/api/assistant' }); + const { + status, + messages, + input, + submitMessage, + handleInputChange, + error, + stop, + } = useAssistant({ api: '/api/assistant' }); // When status changes to accepting messages, focus the input: const inputRef = useRef(null); @@ -64,12 +71,19 @@ export default function Chat() { + + ); } diff --git a/examples/next-openai/package.json b/examples/next-openai/package.json index 14141680de2b..0816759b3e4f 100644 --- a/examples/next-openai/package.json +++ b/examples/next-openai/package.json @@ -12,7 +12,7 @@ "@ai-sdk/openai": "latest", "ai": "latest", "next": "latest", - "openai": "4.29.0", + "openai": "4.42.0", "react": "18.2.0", "react-dom": "^18.2.0", "zod": "3.23.4" diff --git a/packages/core/package.json b/packages/core/package.json index 786c6af0dfdb..cf72ad8b8fe4 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -117,7 +117,7 @@ "jsdom": "^23.0.0", "langchain": "0.0.196", "msw": "2.0.9", - "openai": "4.29.0", + "openai": "4.42.0", "react-dom": "^18.2.0", "react-server-dom-webpack": "18.3.0-canary-eb33bd747-20240312", "solid-js": "^1.8.7", @@ -127,6 +127,7 @@ "zod": "3.22.4" }, "peerDependencies": { + "openai": "^4.42.0", "react": "^18.2.0", "solid-js": "^1.7.7", "svelte": "^3.0.0 || ^4.0.0", @@ -148,6 +149,9 @@ }, "zod": { "optional": true + }, + "openai": { + "optional": true } }, "engines": { diff --git a/packages/core/react/use-assistant.ts b/packages/core/react/use-assistant.ts index b9baa4c3362b..3d4d8f215b44 100644 --- a/packages/core/react/use-assistant.ts +++ b/packages/core/react/use-assistant.ts @@ -1,10 +1,11 @@ /* eslint-disable react-hooks/rules-of-hooks */ -import { useState } from 'react'; - +import { isAbortError } from '@ai-sdk/provider-utils'; +import { useCallback, useRef, useState } from 'react'; import { generateId } from '../shared/generate-id'; import { readDataStream } from '../shared/read-data-stream'; import { CreateMessage, Message } from '../shared/types'; +import { abort } from 'node:process'; export type AssistantStatus = 'in_progress' | 'awaiting_message'; @@ -42,6 +43,11 @@ export type UseAssistantHelpers = { }, ) => Promise; + /** +Abort the current request immediately, keep the generated tokens if any. + */ + stop: () => void; + /** * setState-powered method to update the input value. */ @@ -135,6 +141,16 @@ export function useAssistant({ setInput(event.target.value); }; + // Abort controller to cancel the current API call. + const abortControllerRef = useRef(null); + + const stop = useCallback(() => { + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + } + }, []); + const append = async ( message: Message | CreateMessage, requestOptions?: { @@ -153,10 +169,15 @@ export function useAssistant({ setInput(''); + const abortController = new AbortController(); + try { + abortControllerRef.current = abortController; + const result = await fetch(api, { method: 'POST', credentials, + signal: abortController.signal, headers: { 'Content-Type': 'application/json', ...headers }, body: JSON.stringify({ ...body, @@ -240,14 +261,21 @@ export function useAssistant({ } } } catch (error) { + // Ignore abort errors as they are expected when the user cancels the request: + if (isAbortError(error) && abortController.signal.aborted) { + abortControllerRef.current = null; + return; + } + if (onError && error instanceof Error) { onError(error); } setError(error as Error); + } finally { + abortControllerRef.current = null; + setStatus('awaiting_message'); } - - setStatus('awaiting_message'); }; const submitMessage = async ( @@ -276,6 +304,7 @@ export function useAssistant({ submitMessage, status, error, + stop, }; } diff --git a/packages/core/streams/assistant-response.ts b/packages/core/streams/assistant-response.ts index ead8f3a942bd..f9f984720b0f 100644 --- a/packages/core/streams/assistant-response.ts +++ b/packages/core/streams/assistant-response.ts @@ -1,4 +1,4 @@ -import { AssistantStream } from 'openai/lib/AssistantStream'; +import { type AssistantStream } from 'openai/lib/AssistantStream'; import { Run } from 'openai/resources/beta/threads/runs/runs'; import { formatStreamPart } from '../shared/stream-parts'; import { AssistantMessage, DataMessage } from '../shared/types'; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 6aaffb1cb5b5..0ec45bfd982b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -663,8 +663,8 @@ importers: specifier: latest version: 14.2.3(react-dom@18.2.0)(react@18.2.0) openai: - specifier: 4.29.0 - version: 4.29.0 + specifier: 4.42.0 + version: 4.42.0 react: specifier: 18.2.0 version: 18.2.0 @@ -1295,8 +1295,8 @@ importers: specifier: 2.0.9 version: 2.0.9(typescript@5.1.3) openai: - specifier: 4.29.0 - version: 4.29.0 + specifier: 4.42.0 + version: 4.42.0 react-dom: specifier: ^18.2.0 version: 18.2.0(react@18.2.0) @@ -12836,7 +12836,7 @@ packages: langchainhub: 0.0.6 langsmith: 0.0.48 ml-distance: 4.0.1 - openai: 4.28.4 + openai: 4.42.0 openapi-types: 12.1.3 p-retry: 4.6.2 uuid: 9.0.1 @@ -14398,16 +14398,16 @@ packages: web-streams-polyfill: 3.2.1 transitivePeerDependencies: - encoding + dev: false - /openai@4.29.0: - resolution: {integrity: sha512-ic6C681bSow1XQdKhADthM/OOKqNL05M1gCFLx1mRqLJ+yH49v6qnvaWQ76kwqI/IieCuVTXfRfTk3sz4cB45w==} + /openai@4.42.0: + resolution: {integrity: sha512-xbiQQ2YNqdkE6cHqeWKa7lsAvdYfgp84XiNFOVkAMa6+9KpmOL4hCWCRR6e6I/clpaens/T9XeLVtyC5StXoRw==} hasBin: true dependencies: '@types/node': 18.18.9 '@types/node-fetch': 2.6.9 abort-controller: 3.0.0 agentkeepalive: 4.5.0 - digest-fetch: 1.3.0 form-data-encoder: 1.7.2 formdata-node: 4.4.1 node-fetch: 2.7.0