From 42d4de16c1210b2ff3f14baa30ce1ed23732778f Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 13 Mar 2025 08:01:03 -0600 Subject: [PATCH] [Security solution] Bedrock region fix (#214251) (cherry picked from commit cf73559e2dd5ce6e793b6b92183c43e6e09fc629) --- .../server/connector_types/bedrock/bedrock.ts | 192 +------------- .../connector_types/bedrock/utils.test.ts | 251 ++++++++++++++++++ .../server/connector_types/bedrock/utils.ts | 190 +++++++++++++ 3 files changed, 449 insertions(+), 184 deletions(-) create mode 100644 x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.test.ts create mode 100644 x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.ts diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts index 56c5795a16504..6b983d48ba178 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -9,7 +9,7 @@ import type { ServiceParams } from '@kbn/actions-plugin/server'; import { SubActionConnector } from '@kbn/actions-plugin/server'; import aws from 'aws4'; import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'; -import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; +import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; import type { AxiosError, Method } from 'axios'; import type { IncomingMessage } from 'http'; import { PassThrough } from 'stream'; @@ -36,8 +36,6 @@ import type { InvokeAIRawActionParams, InvokeAIRawActionResponse, RunApiLatestResponse, - BedrockMessage, - BedrockToolChoice, ConverseActionParams, ConverseActionResponse, } from '../../../common/bedrock/types'; @@ -52,7 +50,13 @@ import type { StreamingResponse, } from '../../../common/bedrock/types'; import { DashboardActionParamsSchema } from '../../../common/bedrock/schema'; - +import { + extractRegionId, + formatBedrockBody, + parseContent, + tee, + usesDeprecatedArguments, +} from './utils'; interface SignedRequest { host: string; headers: Record; @@ -461,183 +465,3 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B return res; } } - -const formatBedrockBody = ({ - messages, - stopSequences, - temperature = 0, - system, - maxTokens = DEFAULT_TOKEN_LIMIT, - tools, - toolChoice, -}: { - messages: BedrockMessage[]; - stopSequences?: string[]; - temperature?: number; - maxTokens?: number; - // optional system message to be sent to the API - system?: string; - tools?: Array<{ name: string; description: string }>; - toolChoice?: BedrockToolChoice; -}) => ({ - anthropic_version: 'bedrock-2023-05-31', - ...ensureMessageFormat(messages, system), - max_tokens: maxTokens, - stop_sequences: stopSequences, - temperature, - tools, - tool_choice: toolChoice, -}); - -interface FormattedBedrockMessage { - role: string; - content: string | BedrockMessage['rawContent']; -} - -/** - * Ensures that the messages are in the correct format for the Bedrock API - * If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error - * We combine the messages into a single message to avoid this error - * @param messages - */ -const ensureMessageFormat = ( - messages: BedrockMessage[], - systemPrompt?: string -): { - messages: FormattedBedrockMessage[]; - system?: string; -} => { - let system = systemPrompt ? systemPrompt : ''; - - const newMessages = messages.reduce((acc, m) => { - if (m.role === 'system') { - system = `${system.length ? `${system}\n` : ''}${m.content}`; - return acc; - } - - const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user'); - - if (m.rawContent) { - acc.push({ - role: messageRole(), - content: m.rawContent, - }); - return acc; - } - - const lastMessage = acc[acc.length - 1]; - if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') { - // Bedrock only accepts assistant and user roles. - // If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message - return [ - ...acc.slice(0, -1), - { content: `${lastMessage.content}\n${m.content}`, role: m.role }, - ]; - } - - // force role outside of system to ensure it is either assistant or user - return [...acc, { content: m.content, role: messageRole() }]; - }, []); - - return system.length ? { system, messages: newMessages } : { messages: newMessages }; -}; - -function parseContent(content: Array<{ text?: string; type: string }>): string { - let parsedContent = ''; - if (content.length === 1 && content[0].type === 'text' && content[0].text) { - parsedContent = content[0].text; - } else if (content.length > 1) { - parsedContent = content.reduce((acc, { text }) => (text ? `${acc}\n${text}` : acc), ''); - } - return parsedContent; -} - -const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null; - -function extractRegionId(url: string) { - const match = (url ?? '').match(/bedrock\.(.*?)\.amazonaws\./); - if (match) { - return match[1]; - } else { - // fallback to us-east-1 - return 'us-east-1'; - } -} - -/** - * Splits an async iterator into two independent async iterators which can be independently read from at different speeds. - * @param asyncIterator The async iterator returned from Bedrock to split - */ -function tee( - asyncIterator: SmithyMessageDecoderStream -): [SmithyMessageDecoderStream, SmithyMessageDecoderStream] { - // @ts-ignore options is private, but we need it to create the new streams - const streamOptions = asyncIterator.options; - - const streamLeft = new SmithyMessageDecoderStream(streamOptions); - const streamRight = new SmithyMessageDecoderStream(streamOptions); - - // Queues to store chunks for each stream - const leftQueue: T[] = []; - const rightQueue: T[] = []; - - // Promises for managing when a chunk is available - let leftPending: ((chunk: T | null) => void) | null = null; - let rightPending: ((chunk: T | null) => void) | null = null; - - const distribute = async () => { - for await (const chunk of asyncIterator) { - // Push the chunk into both queues - if (leftPending) { - leftPending(chunk); - leftPending = null; - } else { - leftQueue.push(chunk); - } - - if (rightPending) { - rightPending(chunk); - rightPending = null; - } else { - rightQueue.push(chunk); - } - } - - // Signal the end of the iterator - if (leftPending) { - leftPending(null); - } - if (rightPending) { - rightPending(null); - } - }; - - // Start distributing chunks from the iterator - distribute().catch(() => { - // swallow errors - }); - - // Helper to create an async iterator for each stream - const createIterator = ( - queue: T[], - setPending: (fn: ((chunk: T | null) => void) | null) => void - ) => { - return async function* () { - while (true) { - if (queue.length > 0) { - yield queue.shift()!; - } else { - const chunk = await new Promise((resolve) => setPending(resolve)); - if (chunk === null) break; // End of the stream - yield chunk; - } - } - }; - }; - - // Assign independent async iterators to each stream - streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn)); - streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn)); - - return [streamLeft, streamRight]; -} diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.test.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.test.ts new file mode 100644 index 0000000000000..d5bca4bd5af6a --- /dev/null +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.test.ts @@ -0,0 +1,251 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { + formatBedrockBody, + ensureMessageFormat, + parseContent, + usesDeprecatedArguments, + extractRegionId, + tee, +} from './utils'; +import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; + +describe('formatBedrockBody', () => { + it('formats the body with default values', () => { + const result = formatBedrockBody({ messages: [{ role: 'user', content: 'Hello' }] }); + expect(result).toMatchObject({ + anthropic_version: 'bedrock-2023-05-31', + messages: [{ role: 'user', content: 'Hello' }], + max_tokens: expect.any(Number), + }); + }); +}); + +describe('ensureMessageFormat', () => { + it('combines consecutive messages with the same role', () => { + const messages = [ + { role: 'user', content: 'Hi' }, + { role: 'user', content: 'How are you?' }, + ]; + const result = ensureMessageFormat(messages); + expect(result.messages).toEqual([{ role: 'user', content: 'Hi\nHow are you?' }]); + }); +}); + +describe('parseContent', () => { + it('parses single text content correctly', () => { + const result = parseContent([{ type: 'text', text: 'Sample text' }]); + expect(result).toBe('Sample text'); + }); + + it('parses multiple text contents with line breaks', () => { + const result = parseContent([ + { type: 'text', text: 'Line 1' }, + { type: 'text', text: 'Line 2' }, + ]); + expect(result).toBe(` +Line 1 +Line 2`); + }); +}); + +describe('usesDeprecatedArguments', () => { + it('returns true if prompt exists in body', () => { + const body = JSON.stringify({ prompt: 'Old format' }); + expect(usesDeprecatedArguments(body)).toBe(true); + }); + + it('returns false if prompt is absent', () => { + const body = JSON.stringify({ message: 'New format' }); + expect(usesDeprecatedArguments(body)).toBe(false); + }); +}); + +describe('extractRegionId', () => { + const possibleRuntimeUrls = [ + { url: 'https://bedrock-runtime.us-east-2.amazonaws.com', region: 'us-east-2' }, + { url: 'https://bedrock-runtime-fips.us-east-2.amazonaws.com', region: 'us-east-2' }, + { url: 'https://bedrock-runtime.us-east-1.amazonaws.com', region: 'us-east-1' }, + { url: 'https://bedrock-runtime-fips.us-east-1.amazonaws.com', region: 'us-east-1' }, + { url: 'https://bedrock-runtime.us-west-2.amazonaws.com', region: 'us-west-2' }, + { url: 'https://bedrock-runtime-fips.us-west-2.amazonaws.com', region: 'us-west-2' }, + { url: 'https://bedrock-runtime.ap-south-2.amazonaws.com', region: 'ap-south-2' }, + { url: 'https://bedrock-runtime.ap-south-1.amazonaws.com', region: 'ap-south-1' }, + { url: 'https://bedrock-runtime.ap-northeast-3.amazonaws.com', region: 'ap-northeast-3' }, + { url: 'https://bedrock-runtime.ap-northeast-2.amazonaws.com', region: 'ap-northeast-2' }, + { url: 'https://bedrock-runtime.ap-southeast-1.amazonaws.com', region: 'ap-southeast-1' }, + { url: 'https://bedrock-runtime.ap-southeast-2.amazonaws.com', region: 'ap-southeast-2' }, + { url: 'https://bedrock-runtime.ap-northeast-1.amazonaws.com', region: 'ap-northeast-1' }, + { url: 'https://bedrock-runtime.ca-central-1.amazonaws.com', region: 'ca-central-1' }, + { url: 'https://bedrock-runtime-fips.ca-central-1.amazonaws.com', region: 'ca-central-1' }, + { url: 'https://bedrock-runtime.eu-central-1.amazonaws.com', region: 'eu-central-1' }, + { url: 'https://bedrock-runtime.us-gov-east-1.amazonaws.com', region: 'us-gov-east-1' }, + { url: 'https://bedrock-runtime-fips.us-gov-east-1.amazonaws.com', region: 'us-gov-east-1' }, + { url: 'https://bedrock-runtime.us-gov-west-1.amazonaws.com', region: 'us-gov-west-1' }, + { url: 'https://bedrock-runtime-fips.us-gov-west-1.amazonaws.com', region: 'us-gov-west-1' }, + ]; + it.each(possibleRuntimeUrls)( + 'extracts the region correctly from a valid URL', + ({ url, region }) => { + const result = extractRegionId(url); + expect(result).toBe(region); + } + ); + + it('returns default region if no region is found', () => { + const result = extractRegionId('https://invalid.url.com'); + expect(result).toBe('us-east-1'); + }); +}); + +describe('tee', () => { + it('should split a stream into two identical streams', async () => { + const inputData = [1, 2, 3, 4, 5]; + const mockStream = new MockSmithyMessageDecoderStream(inputData, { + someOption: 'test', + }) as unknown as SmithyMessageDecoderStream; + + const [leftStream, rightStream] = tee(mockStream); + + const leftResults: number[] = []; + const rightResults: number[] = []; + + const leftPromise = (async () => { + for await (const chunk of leftStream) { + leftResults.push(chunk); + } + })(); + + const rightPromise = (async () => { + for await (const chunk of rightStream) { + rightResults.push(chunk); + } + })(); + + await Promise.all([leftPromise, rightPromise]); + + expect(leftResults).toEqual(inputData); + expect(rightResults).toEqual(inputData); + }); + + it('should handle empty streams', async () => { + const mockStream = new MockSmithyMessageDecoderStream([], { + someOption: 'test', + }) as unknown as SmithyMessageDecoderStream; + + const [leftStream, rightStream] = tee(mockStream); + + const leftResults: number[] = []; + const rightResults: number[] = []; + const leftPromise = (async () => { + for await (const chunk of leftStream) { + leftResults.push(chunk); + } + })(); + + const rightPromise = (async () => { + for await (const chunk of rightStream) { + rightResults.push(chunk); + } + })(); + + await Promise.all([leftPromise, rightPromise]); + expect(leftResults).toEqual([]); + expect(rightResults).toEqual([]); + }); + + it('should preserve stream options', () => { + const options = { someOption: 'test' }; + const mockStream = new MockSmithyMessageDecoderStream( + [], + options + ) as unknown as SmithyMessageDecoderStream; + + const [leftStream, rightStream] = tee(mockStream); + + // @ts-ignore options is private, but we need it to create the new streams + expect(leftStream.options).toEqual(options); + // @ts-ignore options is private, but we need it to create the new streams + expect(rightStream.options).toEqual(options); + }); + + it('should handle streams with a single element', async () => { + const inputData = [1]; + const mockStream = new MockSmithyMessageDecoderStream(inputData, { + someOption: 'test', + }) as unknown as SmithyMessageDecoderStream; + + const [leftStream, rightStream] = tee(mockStream); + + const leftResults: number[] = []; + const rightResults: number[] = []; + const leftPromise = (async () => { + for await (const chunk of leftStream) { + leftResults.push(chunk); + } + })(); + + const rightPromise = (async () => { + for await (const chunk of rightStream) { + rightResults.push(chunk); + } + })(); + + await Promise.all([leftPromise, rightPromise]); + expect(leftResults).toEqual(inputData); + expect(rightResults).toEqual(inputData); + }); + + it('should handle streams with many elements', async () => { + const inputData = Array.from({ length: 1000 }, (_, i) => i); + const mockStream = new MockSmithyMessageDecoderStream(inputData, { + someOption: 'test', + }) as unknown as SmithyMessageDecoderStream; + + const [leftStream, rightStream] = tee(mockStream); + + const leftResults: number[] = []; + const rightResults: number[] = []; + const leftPromise = (async () => { + for await (const chunk of leftStream) { + leftResults.push(chunk); + } + })(); + + const rightPromise = (async () => { + for await (const chunk of rightStream) { + rightResults.push(chunk); + } + })(); + + await Promise.all([leftPromise, rightPromise]); + + expect(leftResults).toEqual(inputData); + expect(rightResults).toEqual(inputData); + }); +}); + +class MockSmithyMessageDecoderStream { + private data: T[]; + private currentIndex: number; + public options: {}; + + constructor(data: T[], options?: {}) { + this.data = data; + this.currentIndex = 0; + this.options = options || {}; + } + + async *[Symbol.asyncIterator](): AsyncIterator { + while (this.currentIndex < this.data.length) { + yield this.data[this.currentIndex++]; + // Add a small delay for async behavior simulation (optional) + await new Promise((resolve) => setTimeout(resolve, 0)); + } + } +} diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.ts new file mode 100644 index 0000000000000..ee6ed2db25dd7 --- /dev/null +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/utils.ts @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; +import { DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants'; +import type { BedrockMessage, BedrockToolChoice } from '../../../common/bedrock/types'; + +export const formatBedrockBody = ({ + messages, + stopSequences, + temperature = 0, + system, + maxTokens = DEFAULT_TOKEN_LIMIT, + tools, + toolChoice, +}: { + messages: BedrockMessage[]; + stopSequences?: string[]; + temperature?: number; + maxTokens?: number; + // optional system message to be sent to the API + system?: string; + tools?: Array<{ name: string; description: string }>; + toolChoice?: BedrockToolChoice; +}) => ({ + anthropic_version: 'bedrock-2023-05-31', + ...ensureMessageFormat(messages, system), + max_tokens: maxTokens, + stop_sequences: stopSequences, + temperature, + tools, + tool_choice: toolChoice, +}); + +interface FormattedBedrockMessage { + role: string; + content: string | BedrockMessage['rawContent']; +} + +/** + * Ensures that the messages are in the correct format for the Bedrock API + * If 2 user or 2 assistant messages are sent in a row, Bedrock throws an error + * We combine the messages into a single message to avoid this error + * @param messages + */ +export const ensureMessageFormat = ( + messages: BedrockMessage[], + systemPrompt?: string +): { + messages: FormattedBedrockMessage[]; + system?: string; +} => { + let system = systemPrompt ? systemPrompt : ''; + + const newMessages = messages.reduce((acc, m) => { + if (m.role === 'system') { + system = `${system.length ? `${system}\n` : ''}${m.content}`; + return acc; + } + + const messageRole = () => (['assistant', 'ai'].includes(m.role) ? 'assistant' : 'user'); + + if (m.rawContent) { + acc.push({ + role: messageRole(), + content: m.rawContent, + }); + return acc; + } + + const lastMessage = acc[acc.length - 1]; + if (lastMessage && lastMessage.role === m.role && typeof lastMessage.content === 'string') { + // Bedrock only accepts assistant and user roles. + // If 2 user or 2 assistant messages are sent in a row, combine the messages into a single message + return [ + ...acc.slice(0, -1), + { content: `${lastMessage.content}\n${m.content}`, role: m.role }, + ]; + } + + // force role outside of system to ensure it is either assistant or user + return [...acc, { content: m.content, role: messageRole() }]; + }, []); + + return system.length ? { system, messages: newMessages } : { messages: newMessages }; +}; + +export function parseContent(content: Array<{ text?: string; type: string }>): string { + let parsedContent = ''; + if (content.length === 1 && content[0].type === 'text' && content[0].text) { + parsedContent = content[0].text; + } else if (content.length > 1) { + parsedContent = content.reduce((acc, { text }) => (text ? `${acc}\n${text}` : acc), ''); + } + return parsedContent; +} + +export const usesDeprecatedArguments = (body: string): boolean => JSON.parse(body)?.prompt != null; + +export function extractRegionId(url: string) { + const match = (url ?? '').match(/https:\/\/.*?\.([a-z\-0-9]+)\.amazonaws\.com/); + if (match) { + return match[1]; + } else { + // fallback to us-east-1 + return 'us-east-1'; + } +} + +/** + * Splits an async iterator into two independent async iterators which can be independently read from at different speeds. + * @param asyncIterator The async iterator returned from Bedrock to split + */ +export function tee( + asyncIterator: SmithyMessageDecoderStream +): [SmithyMessageDecoderStream, SmithyMessageDecoderStream] { + // @ts-ignore options is private, but we need it to create the new streams + const streamOptions = asyncIterator.options; + + const streamLeft = new SmithyMessageDecoderStream(streamOptions); + const streamRight = new SmithyMessageDecoderStream(streamOptions); + + // Queues to store chunks for each stream + const leftQueue: T[] = []; + const rightQueue: T[] = []; + + // Promises for managing when a chunk is available + let leftPending: ((chunk: T | null) => void) | null = null; + let rightPending: ((chunk: T | null) => void) | null = null; + + const distribute = async () => { + for await (const chunk of asyncIterator) { + // Push the chunk into both queues + if (leftPending) { + leftPending(chunk); + leftPending = null; + } else { + leftQueue.push(chunk); + } + + if (rightPending) { + rightPending(chunk); + rightPending = null; + } else { + rightQueue.push(chunk); + } + } + + // Signal the end of the iterator + if (leftPending) { + leftPending(null); + } + if (rightPending) { + rightPending(null); + } + }; + + // Start distributing chunks from the iterator + distribute().catch(() => { + // swallow errors + }); + + // Helper to create an async iterator for each stream + const createIterator = ( + queue: T[], + setPending: (fn: ((chunk: T | null) => void) | null) => void + ) => { + return async function* () { + while (true) { + if (queue.length > 0) { + yield queue.shift()!; + } else { + const chunk = await new Promise((resolve) => setPending(resolve)); + if (chunk === null) break; // End of the stream + yield chunk; + } + } + }; + }; + + // Assign independent async iterators to each stream + streamLeft[Symbol.asyncIterator] = createIterator(leftQueue, (fn) => (leftPending = fn)); + streamRight[Symbol.asyncIterator] = createIterator(rightQueue, (fn) => (rightPending = fn)); + + return [streamLeft, streamRight]; +}