-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix (ai/ui): decouple StreamData chunks from LLM stream (#1613)
- Loading branch information
Showing
11 changed files
with
337 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
'ai': patch | ||
--- | ||
|
||
fix (ai/ui): decouple StreamData chunks from LLM stream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import { openai } from '@ai-sdk/openai'; | ||
import { StreamData, StreamingTextResponse, streamText } from 'ai'; | ||
|
||
export const dynamic = 'force-dynamic'; | ||
export const maxDuration = 60; | ||
|
||
export async function POST(req: Request) { | ||
const { messages } = await req.json(); | ||
|
||
const result = await streamText({ | ||
model: openai('gpt-4-turbo'), | ||
messages, | ||
}); | ||
|
||
// optional: use stream data | ||
const data = new StreamData(); | ||
|
||
data.append('initialized call'); | ||
|
||
return new StreamingTextResponse( | ||
result.toAIStream({ | ||
onFinal() { | ||
data.append('call completed'); | ||
data.close(); | ||
}, | ||
}), | ||
{}, | ||
data, | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
'use client'; | ||
|
||
import { Message, useChat } from 'ai/react'; | ||
|
||
export default function Chat() { | ||
const { messages, input, handleInputChange, handleSubmit, data } = useChat({ | ||
api: '/api/use-chat-streamdata', | ||
}); | ||
|
||
return ( | ||
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch"> | ||
{data && ( | ||
<pre className="p-4 text-sm bg-gray-100"> | ||
{JSON.stringify(data, null, 2)} | ||
</pre> | ||
)} | ||
|
||
{messages?.map((m: Message) => ( | ||
<div key={m.id} className="whitespace-pre-wrap"> | ||
<strong>{`${m.role}: `}</strong> | ||
{m.content} | ||
<br /> | ||
<br /> | ||
</div> | ||
))} | ||
|
||
<form onSubmit={handleSubmit}> | ||
<input | ||
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl" | ||
value={input} | ||
placeholder="Say something..." | ||
onChange={handleInputChange} | ||
/> | ||
</form> | ||
</div> | ||
); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import { expect, it } from 'vitest'; | ||
import { mergeStreams } from './merge-streams'; | ||
import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array'; | ||
import { convertArrayToReadableStream } from '../test/convert-array-to-readable-stream'; | ||
|
||
it('should prioritize the first stream over the second stream', async () => { | ||
const stream1 = convertArrayToReadableStream(['1a', '1b', '1c']); | ||
const stream2 = convertArrayToReadableStream(['2a', '2b', '2c']); | ||
|
||
const mergedStream = mergeStreams(stream1, stream2); | ||
|
||
expect(await convertReadableStreamToArray(mergedStream)).toEqual([ | ||
'1a', | ||
'1b', | ||
'1c', | ||
'2a', | ||
'2b', | ||
'2c', | ||
]); | ||
}); | ||
|
||
it('should return values from the 2nd stream until the 1st stream has values', async () => { | ||
let stream1Controller: ReadableStreamDefaultController<string> | undefined; | ||
const stream1 = new ReadableStream({ | ||
start(controller) { | ||
stream1Controller = controller; | ||
}, | ||
}); | ||
|
||
let stream2Controller: ReadableStreamDefaultController<string> | undefined; | ||
const stream2 = new ReadableStream({ | ||
start(controller) { | ||
stream2Controller = controller; | ||
}, | ||
}); | ||
|
||
const mergedStream = mergeStreams(stream1, stream2); | ||
|
||
const result: string[] = []; | ||
const reader = mergedStream.getReader(); | ||
|
||
async function pull() { | ||
const { value, done } = await reader.read(); | ||
result.push(value!); | ||
} | ||
|
||
stream2Controller!.enqueue('2a'); | ||
stream2Controller!.enqueue('2b'); | ||
|
||
await pull(); | ||
await pull(); | ||
|
||
stream2Controller!.enqueue('2c'); | ||
stream2Controller!.enqueue('2d'); // comes later | ||
stream1Controller!.enqueue('1a'); | ||
stream2Controller!.enqueue('2e'); // comes later | ||
stream1Controller!.enqueue('1b'); | ||
stream1Controller!.enqueue('1c'); | ||
stream2Controller!.enqueue('2f'); | ||
|
||
await pull(); | ||
await pull(); | ||
await pull(); | ||
await pull(); | ||
await pull(); | ||
|
||
stream1Controller!.close(); | ||
stream2Controller!.close(); | ||
|
||
await pull(); | ||
await pull(); | ||
|
||
expect(result).toEqual([ | ||
'2a', | ||
'2b', | ||
'2c', | ||
'1a', | ||
'1b', | ||
'1c', | ||
'2d', | ||
'2e', | ||
'2f', | ||
]); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/** | ||
* Merges two readable streams into a single readable stream, emitting values | ||
* from each stream as they become available. | ||
* | ||
* The first stream is prioritized over the second stream. If both streams have | ||
* values available, the first stream's value is emitted first. | ||
* | ||
* @template VALUE1 - The type of values emitted by the first stream. | ||
* @template VALUE2 - The type of values emitted by the second stream. | ||
* @param {ReadableStream<VALUE1>} stream1 - The first readable stream. | ||
* @param {ReadableStream<VALUE2>} stream2 - The second readable stream. | ||
* @returns {ReadableStream<VALUE1 | VALUE2>} A new readable stream that emits values from both input streams. | ||
*/ | ||
export function mergeStreams<VALUE1, VALUE2>( | ||
stream1: ReadableStream<VALUE1>, | ||
stream2: ReadableStream<VALUE2>, | ||
): ReadableStream<VALUE1 | VALUE2> { | ||
const reader1 = stream1.getReader(); | ||
const reader2 = stream2.getReader(); | ||
|
||
let lastRead1: Promise<ReadableStreamReadResult<VALUE1>> | undefined = | ||
undefined; | ||
let lastRead2: Promise<ReadableStreamReadResult<VALUE2>> | undefined = | ||
undefined; | ||
|
||
let stream1Done = false; | ||
let stream2Done = false; | ||
|
||
// only use when stream 2 is done: | ||
async function readStream1( | ||
controller: ReadableStreamDefaultController<VALUE1 | VALUE2>, | ||
) { | ||
try { | ||
if (lastRead1 == null) { | ||
lastRead1 = reader1.read(); | ||
} | ||
|
||
const result = await lastRead1; | ||
lastRead1 = undefined; | ||
|
||
if (!result.done) { | ||
controller.enqueue(result.value); | ||
} else { | ||
controller.close(); | ||
} | ||
} catch (error) { | ||
controller.error(error); | ||
} | ||
} | ||
|
||
// only use when stream 1 is done: | ||
async function readStream2( | ||
controller: ReadableStreamDefaultController<VALUE1 | VALUE2>, | ||
) { | ||
try { | ||
if (lastRead2 == null) { | ||
lastRead2 = reader2.read(); | ||
} | ||
|
||
const result = await lastRead2; | ||
lastRead2 = undefined; | ||
|
||
if (!result.done) { | ||
controller.enqueue(result.value); | ||
} else { | ||
controller.close(); | ||
} | ||
} catch (error) { | ||
controller.error(error); | ||
} | ||
} | ||
|
||
return new ReadableStream<VALUE1 | VALUE2>({ | ||
async pull(controller) { | ||
try { | ||
// stream 1 is done, we can only read from stream 2: | ||
if (stream1Done) { | ||
readStream2(controller); | ||
return; | ||
} | ||
|
||
// stream 2 is done, we can only read from stream 1: | ||
if (stream2Done) { | ||
readStream1(controller); | ||
return; | ||
} | ||
|
||
// pull the next value from the stream that was read last: | ||
if (lastRead1 == null) { | ||
lastRead1 = reader1.read(); | ||
} | ||
if (lastRead2 == null) { | ||
lastRead2 = reader2.read(); | ||
} | ||
|
||
// Note on Promise.race (prioritizing stream 1 over stream 2): | ||
// If the iterable contains one or more non-promise values and/or an already settled promise, | ||
// then Promise.race() will settle to the first of these values found in the iterable. | ||
const { result, reader } = await Promise.race([ | ||
lastRead1.then(result => ({ result, reader: reader1 })), | ||
lastRead2.then(result => ({ result, reader: reader2 })), | ||
]); | ||
|
||
if (!result.done) { | ||
controller.enqueue(result.value); | ||
} | ||
|
||
if (reader === reader1) { | ||
lastRead1 = undefined; | ||
if (result.done) { | ||
// stream 1 is done, we can only read from stream 2: | ||
readStream2(controller); | ||
stream1Done = true; | ||
} | ||
} else { | ||
lastRead2 = undefined; | ||
// stream 2 is done, we can only read from stream 1: | ||
if (result.done) { | ||
stream2Done = true; | ||
readStream1(controller); | ||
} | ||
} | ||
} catch (error) { | ||
controller.error(error); | ||
} | ||
}, | ||
cancel() { | ||
reader1.cancel(); | ||
reader2.cancel(); | ||
}, | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.