Skip to content

Commit

Permalink
refactor (ai/core): simplify createAsyncIterableStream (#4138)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Dec 18, 2024
1 parent e956eed commit ef39980
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 101 deletions.
66 changes: 38 additions & 28 deletions packages/ai/core/generate-object/output-strategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,35 +233,45 @@ const arrayOutputStrategy = <ELEMENT>(
) {
let publishedElements = 0;

return createAsyncIterableStream(originalStream, {
transform(chunk, controller) {
switch (chunk.type) {
case 'object': {
const array = chunk.object;

// publish new elements one by one:
for (; publishedElements < array.length; publishedElements++) {
controller.enqueue(array[publishedElements]);
return createAsyncIterableStream(
originalStream.pipeThrough(
new TransformStream<ObjectStreamPart<ELEMENT[]>, ELEMENT>({
transform(chunk, controller) {
switch (chunk.type) {
case 'object': {
const array = chunk.object;

// publish new elements one by one:
for (
;
publishedElements < array.length;
publishedElements++
) {
controller.enqueue(array[publishedElements]);
}

break;
}

case 'text-delta':
case 'finish':
break;

case 'error':
controller.error(chunk.error);
break;

default: {
const _exhaustiveCheck: never = chunk;
throw new Error(
`Unsupported chunk type: ${_exhaustiveCheck}`,
);
}
}

break;
}

case 'text-delta':
case 'finish':
break;

case 'error':
controller.error(chunk.error);
break;

default: {
const _exhaustiveCheck: never = chunk;
throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
}
}
},
});
},
}),
),
);
},
};
};
Expand Down
102 changes: 53 additions & 49 deletions packages/ai/core/generate-object/stream-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -954,28 +954,32 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
}

get partialObjectStream(): AsyncIterableStream<PARTIAL> {
return createAsyncIterableStream(this.stitchableStream.stream, {
transform(chunk, controller) {
switch (chunk.type) {
case 'object':
controller.enqueue(chunk.object);
break;

case 'text-delta':
case 'finish':
break;

case 'error':
controller.error(chunk.error);
break;

default: {
const _exhaustiveCheck: never = chunk;
throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
}
}
},
});
return createAsyncIterableStream(
this.stitchableStream.stream.pipeThrough(
new TransformStream<ObjectStreamPart<PARTIAL>, PARTIAL>({
transform(chunk, controller) {
switch (chunk.type) {
case 'object':
controller.enqueue(chunk.object);
break;

case 'text-delta':
case 'finish':
break;

case 'error':
controller.error(chunk.error);
break;

default: {
const _exhaustiveCheck: never = chunk;
throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
}
}
},
}),
),
);
}

get elementStream(): ELEMENT_STREAM {
Expand All @@ -985,36 +989,36 @@ class DefaultStreamObjectResult<PARTIAL, RESULT, ELEMENT_STREAM>
}

get textStream(): AsyncIterableStream<string> {
return createAsyncIterableStream(this.stitchableStream.stream, {
transform(chunk, controller) {
switch (chunk.type) {
case 'text-delta':
controller.enqueue(chunk.textDelta);
break;

case 'object':
case 'finish':
break;

case 'error':
controller.error(chunk.error);
break;

default: {
const _exhaustiveCheck: never = chunk;
throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
}
}
},
});
return createAsyncIterableStream(
this.stitchableStream.stream.pipeThrough(
new TransformStream<ObjectStreamPart<PARTIAL>, string>({
transform(chunk, controller) {
switch (chunk.type) {
case 'text-delta':
controller.enqueue(chunk.textDelta);
break;

case 'object':
case 'finish':
break;

case 'error':
controller.error(chunk.error);
break;

default: {
const _exhaustiveCheck: never = chunk;
throw new Error(`Unsupported chunk type: ${_exhaustiveCheck}`);
}
}
},
}),
),
);
}

get fullStream(): AsyncIterableStream<ObjectStreamPart<PARTIAL>> {
return createAsyncIterableStream(this.stitchableStream.stream, {
transform(chunk, controller) {
controller.enqueue(chunk);
},
});
return createAsyncIterableStream(this.stitchableStream.stream);
}

pipeTextStreamToResponse(response: ServerResponse, init?: ResponseInit) {
Expand Down
28 changes: 14 additions & 14 deletions packages/ai/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1180,23 +1180,23 @@ However, the LLM results are expected to be small enough to not cause issues.
}

get textStream(): AsyncIterableStream<string> {
return createAsyncIterableStream(this.teeStream(), {
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
controller.enqueue(chunk.textDelta);
} else if (chunk.type === 'error') {
controller.error(chunk.error);
}
},
});
return createAsyncIterableStream(
this.teeStream().pipeThrough(
new TransformStream<TextStreamPart<TOOLS>, string>({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
controller.enqueue(chunk.textDelta);
} else if (chunk.type === 'error') {
controller.error(chunk.error);
}
},
}),
),
);
}

get fullStream(): AsyncIterableStream<TextStreamPart<TOOLS>> {
return createAsyncIterableStream(this.teeStream(), {
transform(chunk, controller) {
controller.enqueue(chunk);
},
});
return createAsyncIterableStream(this.teeStream());
}

private toDataStreamInternal({
Expand Down
38 changes: 38 additions & 0 deletions packages/ai/core/util/async-iterable-stream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import {
convertArrayToReadableStream,
convertAsyncIterableToArray,
convertReadableStreamToArray,
} from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { createAsyncIterableStream } from './async-iterable-stream';

describe('createAsyncIterableStream()', () => {
it('should read all chunks from a non-empty stream using async iteration', async () => {
const testData = ['Hello', 'World', 'Stream'];

const source = convertArrayToReadableStream(testData);
const asyncIterableStream = createAsyncIterableStream(source);

expect(await convertAsyncIterableToArray(asyncIterableStream)).toEqual(
testData,
);
});

it('should handle an empty stream gracefully', async () => {
const source = convertArrayToReadableStream<string>([]);
const asyncIterableStream = createAsyncIterableStream(source);

expect(await convertAsyncIterableToArray(asyncIterableStream)).toEqual([]);
});

it('should maintain ReadableStream functionality', async () => {
const testData = ['Hello', 'World'];

const source = convertArrayToReadableStream(testData);
const asyncIterableStream = createAsyncIterableStream(source);

expect(await convertReadableStreamToArray(asyncIterableStream)).toEqual(
testData,
);
});
});
17 changes: 7 additions & 10 deletions packages/ai/core/util/async-iterable-stream.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;

export function createAsyncIterableStream<S, T>(
source: ReadableStream<S>,
transformer: Transformer<S, T>,
export function createAsyncIterableStream<T>(
source: ReadableStream<T>,
): AsyncIterableStream<T> {
const transformedStream: any = source.pipeThrough(
new TransformStream(transformer),
);
const stream = source.pipeThrough(new TransformStream<T, T>());

transformedStream[Symbol.asyncIterator] = () => {
const reader = transformedStream.getReader();
(stream as AsyncIterableStream<T>)[Symbol.asyncIterator] = () => {
const reader = stream.getReader();
return {
async next(): Promise<IteratorResult<string>> {
async next(): Promise<IteratorResult<T>> {
const { done, value } = await reader.read();
return done ? { done: true, value: undefined } : { done: false, value };
},
};
};

return transformedStream;
return stream as AsyncIterableStream<T>;
}

0 comments on commit ef39980

Please sign in to comment.