Skip to content

Commit

Permalink
feat (provider/openai): simulated streaming setting (#4132)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Dec 18, 2024
1 parent ed59825 commit 6faab13
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .changeset/modern-gorillas-smile.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/openai-compatible': patch
---

feat (provider/openai-compatible): simulated streaming setting
5 changes: 5 additions & 0 deletions .changeset/stupid-pigs-buy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/openai': patch
---

feat (provider/openai): simulated streaming setting
6 changes: 6 additions & 0 deletions content/providers/01-ai-sdk-providers/01-openai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ The following optional settings are available for OpenAI chat models:
private models or when the images are not publicly accessible.
Defaults to `false`.

- **simulateStreaming** _boolean_

Simulates streaming by using a normal generate call and returning it as a stream.
Enable this if the model that you are using does not support streaming.
Defaults to `false`.

#### Structured Outputs

You can enable [OpenAI structured outputs](https://openai.com/index/introducing-structured-outputs-in-the-api/) by setting the `structuredOutputs` option to `true`.
Expand Down
20 changes: 20 additions & 0 deletions examples/ai-core/src/stream-text/openai-simulated-streaming.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { openai } from '@ai-sdk/openai';
import { streamText } from 'ai';
import 'dotenv/config';

async function main() {
const result = streamText({
model: openai('o1-preview', { simulateStreaming: true }),
prompt: 'Invent a new holiday and describe its traditions.',
});

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}

console.log();
console.log('Token usage:', await result.usage);
console.log('Finish reason:', await result.finishReason);
}

main().catch(console.error);
Original file line number Diff line number Diff line change
Expand Up @@ -1406,3 +1406,157 @@ describe('doStream', () => {
});
});
});

describe('doStream simulated streaming', () => {
const server = new JsonTestServer('https://my.api.com/v1/chat/completions');

server.setupTestEnvironment();

function prepareJsonResponse({
content = '',
tool_calls,
usage = {
prompt_tokens: 4,
total_tokens: 34,
completion_tokens: 30,
},
finish_reason = 'stop',
id = 'chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd',
created = 1711115037,
model = 'gpt-3.5-turbo-0125',
}: {
content?: string;
tool_calls?: Array<{
id: string;
type: 'function';
function: {
name: string;
arguments: string;
};
}>;
usage?: {
prompt_tokens?: number;
total_tokens?: number;
completion_tokens?: number;
};
finish_reason?: string;
created?: number;
id?: string;
model?: string;
} = {}) {
server.responseBodyJson = {
id,
object: 'chat.completion',
created,
model,
choices: [
{
index: 0,
message: {
role: 'assistant',
content,
tool_calls,
},
finish_reason,
},
],
usage,
system_fingerprint: 'fp_3bc1b5746c',
};
}

it('should stream text delta', async () => {
prepareJsonResponse({ content: 'Hello, World!', model: 'o1-preview' });

const model = provider.chatModel('o1', {
simulateStreaming: true,
});

const { stream } = await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(await convertReadableStreamToArray(stream)).toStrictEqual([
{
type: 'response-metadata',
id: 'chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd',
modelId: 'o1-preview',
timestamp: expect.any(Date),
},
{ type: 'text-delta', textDelta: 'Hello, World!' },
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 30 },
logprobs: undefined,
providerMetadata: undefined,
},
]);
});

it('should stream tool calls', async () => {
prepareJsonResponse({
model: 'o1-preview',
tool_calls: [
{
id: 'call_O17Uplv4lJvD6DVdIvFFeRMw',
type: 'function',
function: {
name: 'test-tool',
arguments: '{"value":"Sparkle Day"}',
},
},
],
});

const model = provider.chatModel('o1', {
simulateStreaming: true,
});

const { stream } = await model.doStream({
inputFormat: 'prompt',
mode: {
type: 'regular',
tools: [
{
type: 'function',
name: 'test-tool',
parameters: {
type: 'object',
properties: { value: { type: 'string' } },
required: ['value'],
additionalProperties: false,
$schema: 'http://json-schema.org/draft-07/schema#',
},
},
],
},
prompt: TEST_PROMPT,
});

expect(await convertReadableStreamToArray(stream)).toStrictEqual([
{
type: 'response-metadata',
id: 'chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd',
modelId: 'o1-preview',
timestamp: expect.any(Date),
},
{
type: 'tool-call',
toolCallId: 'call_O17Uplv4lJvD6DVdIvFFeRMw',
toolCallType: 'function',
toolName: 'test-tool',
args: '{"value":"Sparkle Day"}',
},
{
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 30 },
logprobs: undefined,
providerMetadata: undefined,
},
]);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,43 @@ export class OpenAICompatibleChatLanguageModel implements LanguageModelV1 {
async doStream(
options: Parameters<LanguageModelV1['doStream']>[0],
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
if (this.settings.simulateStreaming) {
const result = await this.doGenerate(options);
const simulatedStream = new ReadableStream<LanguageModelV1StreamPart>({
start(controller) {
controller.enqueue({ type: 'response-metadata', ...result.response });
if (result.text) {
controller.enqueue({
type: 'text-delta',
textDelta: result.text,
});
}
if (result.toolCalls) {
for (const toolCall of result.toolCalls) {
controller.enqueue({
type: 'tool-call',
...toolCall,
});
}
}
controller.enqueue({
type: 'finish',
finishReason: result.finishReason,
usage: result.usage,
logprobs: result.logprobs,
providerMetadata: result.providerMetadata,
});
controller.close();
},
});
return {
stream: simulatedStream,
rawCall: result.rawCall,
rawResponse: result.rawResponse,
warnings: result.warnings,
};
}

const { args, warnings } = this.getArgs({ ...options });

const body = JSON.stringify({ ...args, stream: true });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@ A unique identifier representing your end-user, which can help the provider to
monitor and detect abuse.
*/
user?: string;

/**
Simulates streaming by using a normal generate call and returning it as a stream.
Enable this if the model that you are using does not support streaming.

Defaults to `false`.
*/
simulateStreaming?: boolean;
}
Loading

0 comments on commit 6faab13

Please sign in to comment.