diff --git a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts index b30be69b82cae..98828a598562a 100644 --- a/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts +++ b/x-pack/plugins/security_solution/public/assistant/get_comments/stream/stream_observable.ts @@ -48,6 +48,7 @@ export const getStreamObservable = ( `${API_ERROR}\n\n${JSON.parse(decoded).message}` : // all other responses are just strings (handled by subaction invokeStream) decoded; + chunks.push(content); observer.next({ chunks, diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts index 02ad7bcdec6a9..6961b5c5bc3b3 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/openai.ts @@ -234,14 +234,19 @@ export class OpenAIConnector extends SubActionConnector { * parses the proprietary OpenAI response into a string of the response text alone, * returning the response string to the stream */ -const transformToString = () => - new Transform({ +const transformToString = () => { + let lineBuffer: string = ''; + + return new Transform({ transform(chunk, encoding, callback) { const decoder = new TextDecoder(); const encoder = new TextEncoder(); - const nextChunk = decoder - .decode(chunk) - .split('\n') + const lines = decoder.decode(chunk).split('\n'); + lines[0] = lineBuffer + lines[0]; + + lineBuffer = lines.pop() || ''; + + const nextChunk = lines // every line starts with "data: ", we remove it and are left with stringified JSON or the string "[DONE]" .map((str) => str.substring(6)) // filter out empty lines and the "[DONE]" string @@ -255,3 +260,4 @@ const transformToString = () => callback(null, newChunk); }, }); +}; diff --git a/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts b/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts index bfa8c5cb0736f..29e77feb5edaf 100644 --- a/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts +++ b/x-pack/test/alerting_api_integration/common/plugins/actions_simulators/server/bedrock_simulation.ts @@ -7,6 +7,8 @@ import http from 'http'; +import { EventStreamCodec } from '@smithy/eventstream-codec'; +import { fromUtf8, toUtf8 } from '@smithy/util-utf8'; import { ProxyArgs, Simulator } from './simulator'; export class BedrockSimulator extends Simulator { @@ -27,6 +29,10 @@ export class BedrockSimulator extends Simulator { return BedrockSimulator.sendErrorResponse(response); } + if (request.url === '/model/anthropic.claude-v2/invoke-with-response-stream') { + return BedrockSimulator.sendStreamResponse(response); + } + return BedrockSimulator.sendResponse(response); } @@ -36,6 +42,14 @@ export class BedrockSimulator extends Simulator { response.end(JSON.stringify(bedrockSuccessResponse, null, 4)); } + private static sendStreamResponse(response: http.ServerResponse) { + response.statusCode = 200; + response.setHeader('Content-Type', 'application/octet-stream'); + response.setHeader('Transfer-Encoding', 'chunked'); + response.write(encodeBedrockResponse('Hello world, what a unique string!')); + response.end(); + } + private static sendErrorResponse(response: http.ServerResponse) { response.statusCode = 422; response.setHeader('Content-Type', 'application/json;charset=UTF-8'); @@ -52,3 +66,20 @@ export const bedrockFailedResponse = { message: 'Malformed input request: extraneous key [ooooo] is not permitted, please reformat your input and try again.', }; + +function encodeBedrockResponse(completion: string) { + return new EventStreamCodec(toUtf8, fromUtf8).encode({ + headers: { + ':event-type': { type: 'string', value: 'chunk' }, + ':content-type': { type: 'string', value: 'application/json' }, + ':message-type': { type: 'string', value: 'event' }, + }, + body: Uint8Array.from( + Buffer.from( + JSON.stringify({ + bytes: Buffer.from(JSON.stringify({ completion })).toString('base64'), + }) + ) + ), + }); +} diff --git a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts index 67053bef7801b..ff9bba40a228f 100644 --- a/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts +++ b/x-pack/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/bedrock.ts @@ -12,6 +12,7 @@ import { bedrockSuccessResponse, } from '@kbn/actions-simulators-plugin/server/bedrock_simulation'; import { DEFAULT_TOKEN_LIMIT } from '@kbn/stack-connectors-plugin/common/bedrock/constants'; +import { PassThrough } from 'stream'; import { FtrProviderContext } from '../../../../../common/ftr_provider_context'; import { getUrlPrefix, ObjectRemover } from '../../../../../common/lib'; @@ -31,6 +32,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) { const supertest = getService('supertest'); const objectRemover = new ObjectRemover(supertest); const configService = getService('config'); + const retry = getService('retry'); const createConnector = async (apiUrl: string, spaceId?: string) => { const result = await supertest .post(`${getUrlPrefix(spaceId ?? 'default')}/api/actions/connector`) @@ -407,6 +409,43 @@ export default function bedrockTest({ getService }: FtrProviderContext) { data: { message: bedrockSuccessResponse.completion }, }); }); + + it.only('should invoke stream with assistant AI body argument formatted to bedrock expectations', async () => { + await new Promise((resolve, reject) => { + let responseBody: string = ''; + + const passThrough = new PassThrough(); + + supertest + .post(`/internal/elastic_assistant/actions/connector/${bedrockActionId}/_execute`) + .set('kbn-xsrf', 'foo') + .on('error', reject) + .send({ + params: { + subAction: 'invokeStream', + subActionParams: { + messages: [ + { + role: 'user', + content: 'Hello world', + }, + ], + }, + }, + assistantLangChain: false, + }) + .pipe(passThrough); + + passThrough.on('data', (chunk) => { + responseBody += chunk.toString(); + }); + + passThrough.on('end', () => { + expect(responseBody).to.eql('Hello world, what a unique string!'); + resolve(); + }); + }); + }); }); });