Skip to content

Commit

Permalink
fix (ai/ui): decouple StreamData chunks from LLM stream (#1613)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored May 16, 2024
1 parent 1659aba commit a085d42
Show file tree
Hide file tree
Showing 11 changed files with 337 additions and 74 deletions.
5 changes: 5 additions & 0 deletions .changeset/quick-drinks-sort.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

fix (ai/ui): decouple StreamData chunks from LLM stream
1 change: 1 addition & 0 deletions examples/next-openai/app/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { openai } from '@ai-sdk/openai';
import { streamText } from 'ai';

export const dynamic = 'force-dynamic';
export const maxDuration = 60;

export async function POST(req: Request) {
// Extract the `messages` from the body of the request
Expand Down
8 changes: 5 additions & 3 deletions examples/next-openai/app/api/completion/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@ import { openai } from '@ai-sdk/openai';
import { StreamData, StreamingTextResponse, streamText } from 'ai';

export const dynamic = 'force-dynamic';
export const maxDuration = 60;

export async function POST(req: Request) {
// Extract the `prompt` from the body of the request
const { prompt } = await req.json();

const result = await streamText({
model: openai.completion('gpt-3.5-turbo-instruct'),
model: openai('gpt-3.5-turbo-instruct'),
maxTokens: 2000,
prompt,
});

// optional: use stream data
const data = new StreamData();

data.append({ test: 'value' });
data.append('call started');

// Convert the response into a friendly text-stream
// Convert the response to an AI data stream
const stream = result.toAIStream({
onFinal(completion) {
data.append('call completed');
data.close();
},
});
Expand Down
30 changes: 30 additions & 0 deletions examples/next-openai/app/api/use-chat-streamdata/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { openai } from '@ai-sdk/openai';
import { StreamData, StreamingTextResponse, streamText } from 'ai';

export const dynamic = 'force-dynamic';
export const maxDuration = 60;

export async function POST(req: Request) {
const { messages } = await req.json();

const result = await streamText({
model: openai('gpt-4-turbo'),
messages,
});

// optional: use stream data
const data = new StreamData();

data.append('initialized call');

return new StreamingTextResponse(
result.toAIStream({
onFinal() {
data.append('call completed');
data.close();
},
}),
{},
data,
);
}
37 changes: 37 additions & 0 deletions examples/next-openai/app/use-chat-streamdata/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
'use client';

import { Message, useChat } from 'ai/react';

export default function Chat() {
const { messages, input, handleInputChange, handleSubmit, data } = useChat({
api: '/api/use-chat-streamdata',
});

return (
<div className="flex flex-col w-full max-w-md py-24 mx-auto stretch">
{data && (
<pre className="p-4 text-sm bg-gray-100">
{JSON.stringify(data, null, 2)}
</pre>
)}

{messages?.map((m: Message) => (
<div key={m.id} className="whitespace-pre-wrap">
<strong>{`${m.role}: `}</strong>
{m.content}
<br />
<br />
</div>
))}

<form onSubmit={handleSubmit}>
<input
className="fixed bottom-0 w-full max-w-md p-2 mb-8 border border-gray-300 rounded shadow-xl"
value={input}
placeholder="Say something..."
onChange={handleInputChange}
/>
</form>
</div>
);
}
84 changes: 84 additions & 0 deletions packages/core/core/util/merge-streams.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { expect, it } from 'vitest';
import { mergeStreams } from './merge-streams';
import { convertReadableStreamToArray } from '../test/convert-readable-stream-to-array';
import { convertArrayToReadableStream } from '../test/convert-array-to-readable-stream';

it('should prioritize the first stream over the second stream', async () => {
const stream1 = convertArrayToReadableStream(['1a', '1b', '1c']);
const stream2 = convertArrayToReadableStream(['2a', '2b', '2c']);

const mergedStream = mergeStreams(stream1, stream2);

expect(await convertReadableStreamToArray(mergedStream)).toEqual([
'1a',
'1b',
'1c',
'2a',
'2b',
'2c',
]);
});

it('should return values from the 2nd stream until the 1st stream has values', async () => {
let stream1Controller: ReadableStreamDefaultController<string> | undefined;
const stream1 = new ReadableStream({
start(controller) {
stream1Controller = controller;
},
});

let stream2Controller: ReadableStreamDefaultController<string> | undefined;
const stream2 = new ReadableStream({
start(controller) {
stream2Controller = controller;
},
});

const mergedStream = mergeStreams(stream1, stream2);

const result: string[] = [];
const reader = mergedStream.getReader();

async function pull() {
const { value, done } = await reader.read();
result.push(value!);
}

stream2Controller!.enqueue('2a');
stream2Controller!.enqueue('2b');

await pull();
await pull();

stream2Controller!.enqueue('2c');
stream2Controller!.enqueue('2d'); // comes later
stream1Controller!.enqueue('1a');
stream2Controller!.enqueue('2e'); // comes later
stream1Controller!.enqueue('1b');
stream1Controller!.enqueue('1c');
stream2Controller!.enqueue('2f');

await pull();
await pull();
await pull();
await pull();
await pull();

stream1Controller!.close();
stream2Controller!.close();

await pull();
await pull();

expect(result).toEqual([
'2a',
'2b',
'2c',
'1a',
'1b',
'1c',
'2d',
'2e',
'2f',
]);
});
132 changes: 132 additions & 0 deletions packages/core/core/util/merge-streams.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/**
* Merges two readable streams into a single readable stream, emitting values
* from each stream as they become available.
*
* The first stream is prioritized over the second stream. If both streams have
* values available, the first stream's value is emitted first.
*
* @template VALUE1 - The type of values emitted by the first stream.
* @template VALUE2 - The type of values emitted by the second stream.
* @param {ReadableStream<VALUE1>} stream1 - The first readable stream.
* @param {ReadableStream<VALUE2>} stream2 - The second readable stream.
* @returns {ReadableStream<VALUE1 | VALUE2>} A new readable stream that emits values from both input streams.
*/
export function mergeStreams<VALUE1, VALUE2>(
stream1: ReadableStream<VALUE1>,
stream2: ReadableStream<VALUE2>,
): ReadableStream<VALUE1 | VALUE2> {
const reader1 = stream1.getReader();
const reader2 = stream2.getReader();

let lastRead1: Promise<ReadableStreamReadResult<VALUE1>> | undefined =
undefined;
let lastRead2: Promise<ReadableStreamReadResult<VALUE2>> | undefined =
undefined;

let stream1Done = false;
let stream2Done = false;

// only use when stream 2 is done:
async function readStream1(
controller: ReadableStreamDefaultController<VALUE1 | VALUE2>,
) {
try {
if (lastRead1 == null) {
lastRead1 = reader1.read();
}

const result = await lastRead1;
lastRead1 = undefined;

if (!result.done) {
controller.enqueue(result.value);
} else {
controller.close();
}
} catch (error) {
controller.error(error);
}
}

// only use when stream 1 is done:
async function readStream2(
controller: ReadableStreamDefaultController<VALUE1 | VALUE2>,
) {
try {
if (lastRead2 == null) {
lastRead2 = reader2.read();
}

const result = await lastRead2;
lastRead2 = undefined;

if (!result.done) {
controller.enqueue(result.value);
} else {
controller.close();
}
} catch (error) {
controller.error(error);
}
}

return new ReadableStream<VALUE1 | VALUE2>({
async pull(controller) {
try {
// stream 1 is done, we can only read from stream 2:
if (stream1Done) {
readStream2(controller);
return;
}

// stream 2 is done, we can only read from stream 1:
if (stream2Done) {
readStream1(controller);
return;
}

// pull the next value from the stream that was read last:
if (lastRead1 == null) {
lastRead1 = reader1.read();
}
if (lastRead2 == null) {
lastRead2 = reader2.read();
}

// Note on Promise.race (prioritizing stream 1 over stream 2):
// If the iterable contains one or more non-promise values and/or an already settled promise,
// then Promise.race() will settle to the first of these values found in the iterable.
const { result, reader } = await Promise.race([
lastRead1.then(result => ({ result, reader: reader1 })),
lastRead2.then(result => ({ result, reader: reader2 })),
]);

if (!result.done) {
controller.enqueue(result.value);
}

if (reader === reader1) {
lastRead1 = undefined;
if (result.done) {
// stream 1 is done, we can only read from stream 2:
readStream2(controller);
stream1Done = true;
}
} else {
lastRead2 = undefined;
// stream 2 is done, we can only read from stream 1:
if (result.done) {
stream2Done = true;
readStream1(controller);
}
}
} catch (error) {
controller.error(error);
}
},
cancel() {
reader1.cancel();
reader2.cancel();
},
});
}
2 changes: 1 addition & 1 deletion packages/core/streams/inkeep-stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ describe('InkeepStream', () => {
'0:","\n',
'0:" world"\n',
'0:"."\n',
`2:[{"onFinalMetadata":{"chat_session_id":"12345",${recordsCitedSerialized}}}]\n`,
`8:[{${recordsCitedSerialized}}]\n`,
`2:[{"onFinalMetadata":{"chat_session_id":"12345",${recordsCitedSerialized}}}]\n`,
]);
});
});
Loading

0 comments on commit a085d42

Please sign in to comment.