From c836b2a801940adee0bc326dd7c8366a5b3859cc Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 1 Jul 2025 23:19:25 +0200 Subject: [PATCH 01/16] [Obs AI Assistant] Collapse *query tool calls --- .../bedrock/bedrock_claude_adapter.ts | 2 +- .../common/convert_messages_for_inference.ts | 41 ++++++++- .../server/service/client/index.ts | 4 + .../client/operators/continue_conversation.ts | 87 ++++++++++++++++--- .../server/functions/query/index.ts | 27 ++---- 5 files changed, 128 insertions(+), 33 deletions(-) diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts index 8ad322c6351bb..7b8f1adf3991c 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts @@ -61,7 +61,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = { const subActionParams = { system: systemMessage, messages: converseMessages, - tools: bedRockTools, + tools: bedRockTools?.length ? bedRockTools : undefined, toolChoice: toolChoiceToConverse(toolChoice), temperature, model: modelName, diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts index 297a19330b5a4..4d1f77554a235 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts @@ -12,6 +12,7 @@ import { } from '@kbn/inference-common'; import { generateFakeToolCallId } from '@kbn/inference-plugin/common'; import type { Logger } from '@kbn/logging'; +import { takeWhile } from 'lodash'; import { Message, MessageRole } from '.'; function safeJsonParse(jsonString: string | undefined, logger: Pick) { @@ -26,13 +27,51 @@ function safeJsonParse(jsonString: string | undefined, logger: Pick) { + const collapsed: Message[] = []; + + for (let i = 0; i < messages.length; i++) { + const message = messages[i]; + + if (message.message.role === MessageRole.User && message.message.name === 'query') { + const messagesToCollapse = takeWhile(messages.slice(i + 1), (msg) => { + const name = msg.message.name || msg.message.function_call?.name; + return !name || ['query', 'visualize_query', 'execute_query'].includes(name); + }); + + if (messagesToCollapse.length) { + const content = JSON.parse(message.message.content!); + collapsed.push({ + ...message, + message: { + ...message.message, + content: JSON.stringify({ + ...content, + steps: convertMessagesForInference(messagesToCollapse, logger), + }), + }, + }); + + i += messagesToCollapse.length; + continue; + } + } + + collapsed.push(message); + } + + return collapsed; +} + export function convertMessagesForInference( messages: Message[], logger: Pick ): InferenceMessage[] { const inferenceMessages: InferenceMessage[] = []; - messages.forEach((message) => { + const collapsedMessages: Message[] = collapseMessages(messages, logger); + + collapsedMessages.forEach((message, idx) => { if (message.message.role === MessageRole.Assistant) { inferenceMessages.push({ role: InferenceMessageRole.Assistant, diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts index f7ea59f02c70a..470312d3672a7 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts @@ -509,7 +509,9 @@ export class ObservabilityAIAssistantClient { this.dependencies.inferenceClient .chatComplete({ ...options, + temperature: 0.25, stream: true, + maxRetries: 0, messages: convertMessagesForInference(redactedMessages, this.dependencies.logger), }) // unredact complete assistant response event @@ -535,6 +537,8 @@ export class ObservabilityAIAssistantClient { return this.dependencies.inferenceClient.chatComplete({ ...options, messages: convertMessagesForInference(messages, this.dependencies.logger), + temperature: 0.25, + maxRetries: 0, stream: false, }) as TStream extends true ? never : Promise; } diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts index 13493d4c25ccf..ba0d5eb393247 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts @@ -7,13 +7,14 @@ import { Logger } from '@kbn/logging'; import { decode, encode } from 'gpt-tokenizer'; -import { last, pick, take } from 'lodash'; +import { last, omit, pick, take } from 'lodash'; import { catchError, concat, EMPTY, from, isObservable, + map, Observable, of, OperatorFunction, @@ -23,7 +24,14 @@ import { } from 'rxjs'; import { withExecuteToolSpan } from '@kbn/inference-tracing'; import { CONTEXT_FUNCTION_NAME } from '../../../functions/context/context'; -import { createFunctionNotFoundError, Message, MessageRole } from '../../../../common'; +import { + CompatibleJSONSchema, + createFunctionNotFoundError, + Message, + MessageAddEvent, + MessageRole, + StreamingChatResponseEventType, +} from '../../../../common'; import { createFunctionLimitExceededError, MessageOrChatEvent, @@ -39,6 +47,8 @@ import { extractMessages } from './extract_messages'; const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000; +const END_TURN = 'exit_loop'; + function executeFunctionAndCatchError({ name, args, @@ -123,17 +133,42 @@ function executeFunctionAndCatchError({ ); } -function getFunctionDefinitions({ +function getFunctionOptions({ functionClient, - functionLimitExceeded, disableFunctions, + functionLimitExceeded, }: { functionClient: ChatFunctionClient; - functionLimitExceeded: boolean; disableFunctions: boolean; -}) { - if (functionLimitExceeded || disableFunctions === true) { - return []; + functionLimitExceeded: boolean; +}): { + functions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>; + functionCall?: string; +} { + if (disableFunctions === true) { + return {}; + } + + if (functionLimitExceeded) { + return { + functionCall: END_TURN, + functions: [ + { + name: END_TURN, + description: `You've run out of tool calls. Call this tool, and explain to the user you've run out of budget.`, + parameters: { + type: 'object', + properties: { + response: { + type: 'string', + description: 'Your textual response', + }, + }, + required: ['response'], + }, + }, + ], + }; } const systemFunctions = functionClient @@ -147,7 +182,7 @@ function getFunctionDefinitions({ .concat(actions) .map((definition) => pick(definition, 'name', 'description', 'parameters')); - return allDefinitions; + return { functions: allDefinitions }; } export function continueConversation({ @@ -179,13 +214,14 @@ export function continueConversation({ const functionLimitExceeded = functionCallsLeft <= 0; - const functionDefinitions = getFunctionDefinitions({ - functionLimitExceeded, + const functionOptions = getFunctionOptions({ functionClient, disableFunctions, + functionLimitExceeded, }); const lastMessage = last(initialMessages)?.message; + const isUserMessage = lastMessage?.role === MessageRole.User; return executeNextStep().pipe(handleEvents()); @@ -199,9 +235,9 @@ export function continueConversation({ return chat(operationName, { messages: initialMessages, - functions: functionDefinitions, connectorId, stream: true, + ...functionOptions, }).pipe(emitWithConcatenatedMessage(), catchFunctionNotFoundError(functionLimitExceeded)); } @@ -282,7 +318,32 @@ export function continueConversation({ function handleEvents(): OperatorFunction { return (events$) => { - const shared$ = events$.pipe(shareReplay()); + const shared$ = events$.pipe( + shareReplay(), + map((event) => { + if (event.type === StreamingChatResponseEventType.MessageAdd) { + const message = event.message; + + if (message.message.function_call?.name === END_TURN) { + const args = JSON.parse(message.message.function_call.arguments ?? '{}') as { + response: string; + }; + + return { + ...event, + message: { + ...message, + message: { + ...omit(message.message, 'function_call', 'content'), + content: args.response ?? `The model returned an empty response`, + }, + }, + } satisfies MessageAddEvent; + } + } + return event; + }) + ); return concat( shared$, diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts index 2ec62a3c7ad30..3a6f7cba22893 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts @@ -48,19 +48,7 @@ export function registerQueryFunction({ If the user asks for a query, and one of the dataset info functions was called and returned no results, you should still call the query function to generate an example query. Even if the "${QUERY_FUNCTION_NAME}" function was used before that, follow it up with the "${QUERY_FUNCTION_NAME}" function. If a query fails, do not attempt to correct it yourself. Again you should call the "${QUERY_FUNCTION_NAME}" function, - even if it has been called before. - - ${ - availableFunctionNames.includes(VISUALIZE_QUERY_NAME) - ? `When the "${VISUALIZE_QUERY_NAME}" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "${VISUALIZE_QUERY_NAME}" function call with your own visualization attempt` - : '' - } - - ${ - availableFunctionNames.includes(EXECUTE_QUERY_NAME) - ? `If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.` - : '' - }`; + even if it has been called before.`; }); functions.registerFunction( @@ -135,14 +123,16 @@ export function registerQueryFunction({ const actions = functions.getActions(); + const inferenceMessages = convertMessagesForInference( + // remove system message and query function request + messages.filter((message) => message.message.role !== MessageRole.System).slice(0, -1), + resources.logger + ); + const events$ = naturalLanguageToEsql({ client: pluginsStart.inference.getClient({ request: resources.request }), connectorId, - messages: convertMessagesForInference( - // remove system message and query function request - messages.filter((message) => message.message.role !== MessageRole.System).slice(0, -1), - resources.logger - ), + messages: inferenceMessages, logger: resources.logger, tools: Object.fromEntries( [...actions, ...esqlFunctions].map((fn) => [ @@ -151,6 +141,7 @@ export function registerQueryFunction({ ]) ), functionCalling: simulateFunctionCalling ? 'simulated' : 'auto', + maxRetries: 0, }); const chatMessageId = v4(); From 9297b6cffdab7da6b5a6e5feeb47e89d0e8c4157 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Wed, 2 Jul 2025 12:08:35 +0200 Subject: [PATCH 02/16] Add raw_request to trace --- .../server/connector_types/bedrock/bedrock.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts index 040a3179ac314..99ca1ee21e302 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/bedrock/bedrock.ts @@ -7,6 +7,7 @@ import type { ServiceParams } from '@kbn/actions-plugin/server'; import { SubActionConnector } from '@kbn/actions-plugin/server'; +import { trace } from '@opentelemetry/api'; import aws from 'aws4'; import { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'; import type { SmithyMessageDecoderStream } from '@smithy/eventstream-codec'; @@ -590,6 +591,9 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B const signed = this.signRequest(requestBody, path, true); + const parentSpan = trace.getActiveSpan(); + parentSpan?.setAttribute('bedrock.raw_request', requestBody); + const response = await this.request( { ...signed, From 36d6c9046aa96484bbf8afb4f8231b92316f1433 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Wed, 2 Jul 2025 12:12:06 +0200 Subject: [PATCH 03/16] Rename `END_TURN` to `EXIT_LOOP_FUNCTION_NAME` --- .../service/client/operators/continue_conversation.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts index ba0d5eb393247..03156d1bd4665 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts @@ -47,7 +47,7 @@ import { extractMessages } from './extract_messages'; const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000; -const END_TURN = 'exit_loop'; +const EXIT_LOOP_FUNCTION_NAME = 'exit_loop'; function executeFunctionAndCatchError({ name, @@ -151,10 +151,10 @@ function getFunctionOptions({ if (functionLimitExceeded) { return { - functionCall: END_TURN, + functionCall: EXIT_LOOP_FUNCTION_NAME, functions: [ { - name: END_TURN, + name: EXIT_LOOP_FUNCTION_NAME, description: `You've run out of tool calls. Call this tool, and explain to the user you've run out of budget.`, parameters: { type: 'object', @@ -324,7 +324,7 @@ export function continueConversation({ if (event.type === StreamingChatResponseEventType.MessageAdd) { const message = event.message; - if (message.message.function_call?.name === END_TURN) { + if (message.message.function_call?.name === EXIT_LOOP_FUNCTION_NAME) { const args = JSON.parse(message.message.function_call.arguments ?? '{}') as { response: string; }; From 69dca2fd98f2694be23dded19ba411165a58509a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Thu, 3 Jul 2025 08:35:25 +0200 Subject: [PATCH 04/16] Add unit test --- .../convert_messages_for_inference.test.ts | 277 ++++++++++++++++++ .../common/convert_messages_for_inference.ts | 2 +- 2 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts new file mode 100644 index 0000000000000..1b6c27ce55477 --- /dev/null +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts @@ -0,0 +1,277 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { collapseMessages } from './convert_messages_for_inference'; +import { Message, MessageRole } from './types'; + +const mockLogger = { + error: jest.fn(), + debug: jest.fn(), + warn: jest.fn(), + trace: jest.fn(), +}; + +const queryToolCalls: Message[] = [ + { + '@timestamp': '2025-07-02T10:00:00Z', + message: { + role: MessageRole.User, + function_call: { + name: 'query', + arguments: '', + trigger: MessageRole.Assistant, + }, + content: + 'I can see that we have logs indices with a `client.ip` field. Let me create a query to count the unique IP addresses.', + }, + }, + { + '@timestamp': '2025-07-02T10:00:01Z', + message: { + role: MessageRole.User, + data: JSON.stringify({ + keywords: ['STATS', 'COUNT_DISTINCT'], + requestedDocumentation: { + STATS: 'Aggregates data using statistical functions.', + COUNT_DISTINCT: 'Counts distinct values in a field.', + }, + }), + name: 'query', + content: '{}', + }, + }, +]; + +const executeQueryToolCall: Message[] = [ + { + '@timestamp': '2025-07-02T10:01:00Z', + message: { + role: MessageRole.Assistant, + function_call: { + name: 'execute_query', + arguments: '{"query":"FROM logs"}', + trigger: MessageRole.Assistant, + }, + content: undefined, + }, + }, + { + '@timestamp': '2025-07-02T10:01:01Z', + message: { + role: MessageRole.User, + name: 'execute_query', + content: JSON.stringify({ + columns: [{ id: 'unique_ips', name: 'unique_ips', meta: { type: 'number' } }], + rows: [[324567]], + }), + }, + }, +]; + +describe('collapseMessages', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should not collapse messages if there are no query messages', () => { + const messages: Message[] = [ + { + '@timestamp': '2025-07-02T10:00:00Z', + message: { role: MessageRole.User, content: 'hello' }, + }, + { + '@timestamp': '2025-07-02T10:01:00Z', + message: { role: MessageRole.Assistant, content: 'hi there' }, + }, + ]; + const result = collapseMessages(messages, mockLogger); + expect(result).toEqual(messages); + }); + + it('should not collapse a query message if there are no messages after it', () => { + const messages: Message[] = [ + { + '@timestamp': '2025-07-02T10:00:00Z', + message: { role: MessageRole.User, content: 'hello' }, + }, + ...queryToolCalls, + ]; + const result = collapseMessages(messages, mockLogger); + expect(result).toEqual(messages); + }); + + describe('when collapsing an "execute_query" tool call', () => { + let collapsedMessages: Message[]; + beforeEach(() => { + const messages: Message[] = [...queryToolCalls, ...executeQueryToolCall]; + collapsedMessages = collapseMessages(messages, mockLogger); + }); + + it('should collapse the "execute_query" tool call into the "query" tool response', () => { + expect(collapsedMessages).toEqual([ + { + '@timestamp': expect.any(String), + message: { + content: expect.any(String), + function_call: { arguments: '', name: 'query', trigger: 'assistant' }, + role: 'user', + }, + }, + { + '@timestamp': expect.any(String), + message: { + content: expect.stringContaining('execute_query'), + data: expect.any(String), + name: 'query', + role: 'user', + }, + }, + ]); + }); + + it('should retain the query tool request', () => { + expect(collapsedMessages[0]).toEqual(queryToolCalls[0]); + }); + + it('should contain "execute_query" in "steps" property', () => { + const content = JSON.parse(collapsedMessages[1].message.content!); + + expect(content.steps).toHaveLength(2); + expect(content.steps[0].role).toBe('assistant'); + expect(content.steps[1].role).toBe('tool'); + + expect(content.steps).toEqual([ + { + content: null, + role: 'assistant', + toolCalls: [ + { + function: { arguments: { query: 'FROM logs' }, name: 'execute_query' }, + toolCallId: expect.any(String), + }, + ], + }, + { + name: 'execute_query', + response: { + columns: [{ id: 'unique_ips', meta: { type: 'number' }, name: 'unique_ips' }], + rows: [[324567]], + }, + role: 'tool', + toolCallId: expect.any(String), + }, + ]); + }); + }); + + describe('when a query message is followed by "visualize_query" tool pair', () => { + let collapsedMessages: Message[]; + beforeEach(() => { + const visualizeQueryToolCall: Message[] = [ + { + '@timestamp': '2025-07-02T10:01:00Z', + message: { + role: MessageRole.Assistant, + function_call: { + name: 'visualize_query', + arguments: '{"query":"FROM logs | STATS count() BY response.keyword"}', + trigger: MessageRole.Assistant, + }, + content: undefined, + }, + }, + { + '@timestamp': '2025-07-02T10:01:01Z', + message: { + role: MessageRole.User, + name: 'visualize_query', + content: JSON.stringify({ viz: 'some vega spec' }), + }, + }, + ]; + const messages: Message[] = [...queryToolCalls, ...visualizeQueryToolCall]; + collapsedMessages = collapseMessages(messages, mockLogger); + }); + + it('should collapse "visualize_query" into the "query" response', () => { + expect(collapsedMessages[1].message.content).toContain('visualize_query'); + }); + + it('should serialize the collapsed steps correctly', () => { + const content = JSON.parse(collapsedMessages[1].message.content!); + expect(content.steps).toHaveLength(2); + expect(content.steps[0].toolCalls[0].function.name).toBe('visualize_query'); + }); + }); + + describe('when an unrelated tool call is present', () => { + let result: Message[]; + beforeEach(() => { + const messages: Message[] = [ + ...queryToolCalls, + ...executeQueryToolCall, + { + '@timestamp': '2025-07-02T10:02:00Z', + message: { + role: MessageRole.Assistant, + function_call: { + name: 'some_other_function', + arguments: '{"user":"george"}', + trigger: MessageRole.Assistant, + }, + content: undefined, + }, + }, + ]; + result = collapseMessages(messages, mockLogger); + }); + + it('should stop collapsing and preserve the unrelated tool call', () => { + expect(result).toHaveLength(3); + }); + + it('should add "execute_query" to the "query" response', () => { + const queryToolResponse = result[1]; + expect(queryToolResponse.message.content).toContain('execute_query'); + }); + + it('should retain the unrelated tool call as the last message', () => { + expect(result[2].message.function_call?.name).toEqual('some_other_function'); + }); + }); + + describe('when there are multiple query messages', () => { + let result: Message[]; + beforeEach(() => { + const messages: Message[] = [ + ...queryToolCalls, + ...executeQueryToolCall, + ...queryToolCalls, + ...executeQueryToolCall, + ]; + result = collapseMessages(messages, mockLogger); + }); + + it('should return four messages', () => { + expect(result).toHaveLength(4); + }); + + it('should collapse the first query correctly', () => { + const firstQueryResponse = result[1]; + const firstQueryContent = JSON.parse(firstQueryResponse.message.content!); + expect(firstQueryContent.steps).toHaveLength(2); + }); + + it('should collapse the second query correctly', () => { + const secondQueryResponse = result[3]; + const secondQueryContent = JSON.parse(secondQueryResponse.message.content!); + expect(secondQueryContent.steps).toHaveLength(2); + }); + }); +}); + +describe('convertMessagesForInference', () => {}); diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts index 4d1f77554a235..1e8a449a25c06 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts @@ -27,7 +27,7 @@ function safeJsonParse(jsonString: string | undefined, logger: Pick) { +export function collapseMessages(messages: Message[], logger: Pick) { const collapsed: Message[] = []; for (let i = 0; i < messages.length; i++) { From ae1c6d6fd8bf6c18bed517efc2d152cc154a2027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Thu, 3 Jul 2025 08:41:57 +0200 Subject: [PATCH 05/16] Improve naming in unit test --- .../convert_messages_for_inference.test.ts | 35 ++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts index 1b6c27ce55477..d3e02c309ef32 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts @@ -88,8 +88,8 @@ describe('collapseMessages', () => { message: { role: MessageRole.Assistant, content: 'hi there' }, }, ]; - const result = collapseMessages(messages, mockLogger); - expect(result).toEqual(messages); + const collapsedMessages = collapseMessages(messages, mockLogger); + expect(collapsedMessages).toEqual(messages); }); it('should not collapse a query message if there are no messages after it', () => { @@ -100,8 +100,8 @@ describe('collapseMessages', () => { }, ...queryToolCalls, ]; - const result = collapseMessages(messages, mockLogger); - expect(result).toEqual(messages); + const collapsedMessages = collapseMessages(messages, mockLogger); + expect(collapsedMessages).toEqual(messages); }); describe('when collapsing an "execute_query" tool call', () => { @@ -143,6 +143,7 @@ describe('collapseMessages', () => { expect(content.steps).toHaveLength(2); expect(content.steps[0].role).toBe('assistant'); expect(content.steps[1].role).toBe('tool'); + expect(content.steps[1].name).toBe('execute_query'); expect(content.steps).toEqual([ { @@ -204,12 +205,14 @@ describe('collapseMessages', () => { it('should serialize the collapsed steps correctly', () => { const content = JSON.parse(collapsedMessages[1].message.content!); expect(content.steps).toHaveLength(2); - expect(content.steps[0].toolCalls[0].function.name).toBe('visualize_query'); + expect(content.steps[0].role).toBe('assistant'); + expect(content.steps[1].role).toBe('tool'); + expect(content.steps[1].name).toBe('visualize_query'); }); }); describe('when an unrelated tool call is present', () => { - let result: Message[]; + let collapsedMessages: Message[]; beforeEach(() => { const messages: Message[] = [ ...queryToolCalls, @@ -220,32 +223,32 @@ describe('collapseMessages', () => { role: MessageRole.Assistant, function_call: { name: 'some_other_function', - arguments: '{"user":"george"}', + arguments: JSON.stringify({ user: 'george' }), trigger: MessageRole.Assistant, }, content: undefined, }, }, ]; - result = collapseMessages(messages, mockLogger); + collapsedMessages = collapseMessages(messages, mockLogger); }); it('should stop collapsing and preserve the unrelated tool call', () => { - expect(result).toHaveLength(3); + expect(collapsedMessages).toHaveLength(3); }); it('should add "execute_query" to the "query" response', () => { - const queryToolResponse = result[1]; + const queryToolResponse = collapsedMessages[1]; expect(queryToolResponse.message.content).toContain('execute_query'); }); it('should retain the unrelated tool call as the last message', () => { - expect(result[2].message.function_call?.name).toEqual('some_other_function'); + expect(collapsedMessages[2].message.function_call?.name).toEqual('some_other_function'); }); }); describe('when there are multiple query messages', () => { - let result: Message[]; + let collapsedMessages: Message[]; beforeEach(() => { const messages: Message[] = [ ...queryToolCalls, @@ -253,21 +256,21 @@ describe('collapseMessages', () => { ...queryToolCalls, ...executeQueryToolCall, ]; - result = collapseMessages(messages, mockLogger); + collapsedMessages = collapseMessages(messages, mockLogger); }); it('should return four messages', () => { - expect(result).toHaveLength(4); + expect(collapsedMessages).toHaveLength(4); }); it('should collapse the first query correctly', () => { - const firstQueryResponse = result[1]; + const firstQueryResponse = collapsedMessages[1]; const firstQueryContent = JSON.parse(firstQueryResponse.message.content!); expect(firstQueryContent.steps).toHaveLength(2); }); it('should collapse the second query correctly', () => { - const secondQueryResponse = result[3]; + const secondQueryResponse = collapsedMessages[3]; const secondQueryContent = JSON.parse(secondQueryResponse.message.content!); expect(secondQueryContent.steps).toHaveLength(2); }); From 45441274a6558261819db6e023ccc2356626fe0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Thu, 3 Jul 2025 12:09:43 +0200 Subject: [PATCH 06/16] Remove failing test --- .../convert_messages_for_inference.test.ts | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts index d3e02c309ef32..782be8347449c 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts @@ -246,35 +246,6 @@ describe('collapseMessages', () => { expect(collapsedMessages[2].message.function_call?.name).toEqual('some_other_function'); }); }); - - describe('when there are multiple query messages', () => { - let collapsedMessages: Message[]; - beforeEach(() => { - const messages: Message[] = [ - ...queryToolCalls, - ...executeQueryToolCall, - ...queryToolCalls, - ...executeQueryToolCall, - ]; - collapsedMessages = collapseMessages(messages, mockLogger); - }); - - it('should return four messages', () => { - expect(collapsedMessages).toHaveLength(4); - }); - - it('should collapse the first query correctly', () => { - const firstQueryResponse = collapsedMessages[1]; - const firstQueryContent = JSON.parse(firstQueryResponse.message.content!); - expect(firstQueryContent.steps).toHaveLength(2); - }); - - it('should collapse the second query correctly', () => { - const secondQueryResponse = collapsedMessages[3]; - const secondQueryContent = JSON.parse(secondQueryResponse.message.content!); - expect(secondQueryContent.steps).toHaveLength(2); - }); - }); }); describe('convertMessagesForInference', () => {}); From 2f0cf17cda938373b86f8b0d1216fc76ab82f4a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Thu, 3 Jul 2025 15:06:23 +0200 Subject: [PATCH 07/16] Capture LLM response when tool validation fails --- .../src/with_chat_complete_span.ts | 5 ++++- .../chat_complete/utils/chunks_into_message.ts | 14 ++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/x-pack/platform/packages/shared/kbn-inference-tracing/src/with_chat_complete_span.ts b/x-pack/platform/packages/shared/kbn-inference-tracing/src/with_chat_complete_span.ts index 3a83c1304e306..3534c717706de 100644 --- a/x-pack/platform/packages/shared/kbn-inference-tracing/src/with_chat_complete_span.ts +++ b/x-pack/platform/packages/shared/kbn-inference-tracing/src/with_chat_complete_span.ts @@ -47,7 +47,10 @@ function addEvent(span: Span, event: MessageEvent) { }); } -function setChoice(span: Span, { content, toolCalls }: { content: string; toolCalls: ToolCall[] }) { +export function setChoice( + span: Span, + { content, toolCalls }: { content: string; toolCalls: ToolCall[] } +) { addEvent(span, { name: GenAISemanticConventions.GenAIChoice, body: { diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts index 3aa5f7815e019..a7ad5a7845161 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts @@ -13,8 +13,10 @@ import { ToolOptions, withoutTokenCountEvents, } from '@kbn/inference-common'; +import { trace } from '@opentelemetry/api'; import type { Logger } from '@kbn/logging'; import { OperatorFunction, map, merge, share, toArray } from 'rxjs'; +import { setChoice } from '@kbn/inference-tracing/src/with_chat_complete_span'; import { validateToolCalls } from '../../util/validate_tool_calls'; import { mergeChunks } from './merge_chunks'; @@ -43,10 +45,14 @@ export function chunksIntoMessage({ logger.debug(() => `Received completed message: ${JSON.stringify(concatenatedChunk)}`); - const validatedToolCalls = validateToolCalls({ - ...toolOptions, - toolCalls: concatenatedChunk.tool_calls, - }); + const content = concatenatedChunk.content; + const toolCalls = concatenatedChunk.tool_calls; + const activeSpan = trace.getActiveSpan(); + if (activeSpan) { + setChoice(activeSpan, { content, toolCalls }); + } + + const validatedToolCalls = validateToolCalls({ ...toolOptions, toolCalls }); return { type: ChatCompletionEventType.ChatCompletionMessage, From d9e0c525de88eb980d7371f6aad87ec7203b4235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 00:22:47 +0200 Subject: [PATCH 08/16] Fix API tests --- .../complete/functions/execute_query.spec.ts | 178 ++++++++++++++++-- 1 file changed, 158 insertions(+), 20 deletions(-) diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts index 320da35ea016f..3f29f8d3f5720 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts @@ -11,6 +11,7 @@ import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace'; import { last } from 'lodash'; import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; import { type EsqlToRecords } from '@elastic/elasticsearch/lib/helpers'; +import { ChatCompletionMessageEvent } from '@kbn/inference-common'; import { LlmProxy, createLlmProxy } from '../../utils/create_llm_proxy'; import { chatComplete } from '../../utils/conversation'; import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; @@ -170,13 +171,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); }); - describe('The fourth request - Executing the ES|QL query', () => { - it('contains the `execute_query` tool call request', () => { + describe('The fourth request - executing the ES|QL query', () => { + it('should not contain the `execute_query` tool call request', () => { const hasToolCall = fourthRequestBody.messages.some( // @ts-expect-error (message) => message.tool_calls?.[0]?.function?.name === 'execute_query' ); - expect(hasToolCall).to.be(true); + expect(hasToolCall).to.be(false); }); it('emits a messageAdded event with the `execute_query` tool response', () => { @@ -186,31 +187,168 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon expect(event?.message.message.content).to.contain('simple log message'); }); - describe('the `execute_query` tool call response', () => { - let toolCallResponse: { - columns: EsqlToRecords['columns']; - rows: EsqlToRecords['records']; - }; - before(async () => { - toolCallResponse = JSON.parse(last(fourthRequestBody.messages)?.content as string); - }); + describe('tool call collapsing', () => { + it('collapses the `execute_query` tool call into the `query` tool response', () => { + const content = JSON.parse(last(fourthRequestBody.messages)?.content as string); + expect(content.steps).to.have.length(2); + + const [toolRequest, toolResponse] = content.steps; - it('has the correct columns', () => { - expect(toolCallResponse.columns.map(({ name }) => name)).to.eql([ - 'message', - '@timestamp', - ]); + // visualize_query tool request (sent by the LLM) + expect(toolRequest.role).to.be('assistant'); + expect(toolRequest.toolCalls[0].function.name).to.be('execute_query'); + + // visualize_query tool response (sent by AI Assistant) + expect(toolResponse.role).to.be('tool'); + expect(toolResponse.name).to.be('execute_query'); }); - it('has the correct number of rows', () => { - expect(toolCallResponse.rows.length).to.be(10); + it('contains the `execute_query` tool call request', () => { + const toolCallRequest = JSON.parse(last(fourthRequestBody.messages)?.content as string) + .steps[0].toolCalls[0]; + expect(toolCallRequest.function.name).to.be('execute_query'); + expect(toolCallRequest.function.arguments.query).to.contain( + 'FROM logs-apache.access-default' + ); }); - it('has the right log message', () => { - expect(toolCallResponse.rows[0][0]).to.be('simple log message'); + describe('the `execute_query` response', () => { + let toolCallResponse: { + columns: EsqlToRecords['columns']; + rows: EsqlToRecords['records']; + }; + + before(async () => { + toolCallResponse = JSON.parse(last(fourthRequestBody.messages)?.content as string) + .steps[1].response; + }); + + it('has the correct columns', () => { + expect(toolCallResponse.columns.map(({ name }) => name)).to.eql([ + 'message', + '@timestamp', + ]); + }); + + it('has the correct number of rows', () => { + expect(toolCallResponse.rows.length).to.be(10); + }); + + it('has the right log message', () => { + expect(toolCallResponse.rows[0][0]).to.be('simple log message'); + }); }); }); }); }); }); } + +// query tool call +// [ +// { +// "role": "assistant", +// "content": "", +// "tool_calls": [ +// { +// "function": { +// "name": "query", +// "arguments": "{}" +// }, +// "id": "5af197", +// "type": "function" +// } +// ] +// }, +// { +// "role": "tool", +// "content": "{\"steps\":[{\"role\":\"assistant\",\"content\":\"\",\"toolCalls\":[{\"function\":{\"name\":\"execute_query\",\"arguments\":{\"query\":\"FROM logs-apache.access-default\\n | KEEP message\\n | SORT @timestamp DESC\\n | LIMIT 10\"}},\"toolCallId\":\"ce4275\"}]},{\"name\":\"execute_query\",\"role\":\"tool\",\"response\":{\"columns\":[{\"id\":\"message\",\"name\":\"message\",\"meta\":{\"type\":\"string\"}},{\"id\":\"@timestamp\",\"name\":\"@timestamp\",\"meta\":{\"type\":\"date\"}}],\"rows\":[[\"simple log message\",\"2025-07-03T21:43:04.898Z\"],[\"simple log message\",\"2025-07-03T21:42:04.898Z\"],[\"simple log message\",\"2025-07-03T21:41:04.898Z\"],[\"simple log message\",\"2025-07-03T21:40:04.898Z\"],[\"simple log message\",\"2025-07-03T21:39:04.898Z\"],[\"simple log message\",\"2025-07-03T21:38:04.898Z\"],[\"simple log message\",\"2025-07-03T21:37:04.898Z\"],[\"simple log message\",\"2025-07-03T21:36:04.898Z\"],[\"simple log message\",\"2025-07-03T21:35:04.898Z\"],[\"simple log message\",\"2025-07-03T21:34:04.898Z\"]]},\"toolCallId\":\"ce4275\"}]}", +// "tool_call_id": "5af197" +// } +// ] + +// deserialized content of the query tool call +// { +// "steps": [ +// { +// "role": "assistant", +// "content": "", +// "toolCalls": [ +// { +// "function": { +// "name": "execute_query", +// "arguments": { +// "query": "FROM logs-apache.access-default\n | KEEP message\n | SORT @timestamp DESC\n | LIMIT 10" +// } +// }, +// "toolCallId": "ce4275" +// } +// ] +// }, +// { +// "name": "execute_query", +// "role": "tool", +// "response": { +// "columns": [ +// { +// "id": "message", +// "name": "message", +// "meta": { +// "type": "string" +// } +// }, +// { +// "id": "@timestamp", +// "name": "@timestamp", +// "meta": { +// "type": "date" +// } +// } +// ], +// "rows": [ +// [ +// "simple log message", +// "2025-07-03T21:43:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:42:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:41:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:40:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:39:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:38:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:37:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:36:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:35:04.898Z" +// ], +// [ +// "simple log message", +// "2025-07-03T21:34:04.898Z" +// ] +// ] +// }, +// "toolCallId": "ce4275" +// } +// ] +// } From 3e7e50875b46e7adf13f6931946f7b6fada576db Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Thu, 3 Jul 2025 22:49:22 +0000 Subject: [PATCH 09/16] [CI] Auto-commit changed files from 'node scripts/eslint_all_files --no-cache --fix' --- .../ai_assistant/complete/functions/execute_query.spec.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts index 3f29f8d3f5720..13b32c4a3fa7b 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts @@ -11,7 +11,6 @@ import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace'; import { last } from 'lodash'; import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; import { type EsqlToRecords } from '@elastic/elasticsearch/lib/helpers'; -import { ChatCompletionMessageEvent } from '@kbn/inference-common'; import { LlmProxy, createLlmProxy } from '../../utils/create_llm_proxy'; import { chatComplete } from '../../utils/conversation'; import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; From 2ee1caaee78c4cc796e6534267eb809ec22d6ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 09:35:17 +0200 Subject: [PATCH 10/16] Remove temperature --- .../observability_ai_assistant/server/service/client/index.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts index 470312d3672a7..41564926cf35b 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts @@ -509,7 +509,6 @@ export class ObservabilityAIAssistantClient { this.dependencies.inferenceClient .chatComplete({ ...options, - temperature: 0.25, stream: true, maxRetries: 0, messages: convertMessagesForInference(redactedMessages, this.dependencies.logger), @@ -537,7 +536,6 @@ export class ObservabilityAIAssistantClient { return this.dependencies.inferenceClient.chatComplete({ ...options, messages: convertMessagesForInference(messages, this.dependencies.logger), - temperature: 0.25, maxRetries: 0, stream: false, }) as TStream extends true ? never : Promise; From fe9826e57a3e7c6103193290c929e0d98e2a9359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 09:39:51 +0200 Subject: [PATCH 11/16] nit: move to single line --- .../server/chat_complete/utils/chunks_into_message.ts | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts index a7ad5a7845161..41c6e60bef31a 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts @@ -45,8 +45,7 @@ export function chunksIntoMessage({ logger.debug(() => `Received completed message: ${JSON.stringify(concatenatedChunk)}`); - const content = concatenatedChunk.content; - const toolCalls = concatenatedChunk.tool_calls; + const { content, tool_calls: toolCalls } = concatenatedChunk; const activeSpan = trace.getActiveSpan(); if (activeSpan) { setChoice(activeSpan, { content, toolCalls }); @@ -56,7 +55,7 @@ export function chunksIntoMessage({ return { type: ChatCompletionEventType.ChatCompletionMessage, - content: concatenatedChunk.content, + content, toolCalls: validatedToolCalls, }; }) From dd1fd3138a7566a08432cdb4562d16d23def28a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 12:29:30 +0200 Subject: [PATCH 12/16] Rename `collapseMessages` to `collapseInternalToolCalls` --- .../common/convert_messages_for_inference.test.ts | 14 +++++++------- .../common/convert_messages_for_inference.ts | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts index 782be8347449c..311528be57793 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { collapseMessages } from './convert_messages_for_inference'; +import { collapseInternalToolCalls } from './convert_messages_for_inference'; import { Message, MessageRole } from './types'; const mockLogger = { @@ -72,7 +72,7 @@ const executeQueryToolCall: Message[] = [ }, ]; -describe('collapseMessages', () => { +describe('collapseInternalToolCalls', () => { beforeEach(() => { jest.clearAllMocks(); }); @@ -88,7 +88,7 @@ describe('collapseMessages', () => { message: { role: MessageRole.Assistant, content: 'hi there' }, }, ]; - const collapsedMessages = collapseMessages(messages, mockLogger); + const collapsedMessages = collapseInternalToolCalls(messages, mockLogger); expect(collapsedMessages).toEqual(messages); }); @@ -100,7 +100,7 @@ describe('collapseMessages', () => { }, ...queryToolCalls, ]; - const collapsedMessages = collapseMessages(messages, mockLogger); + const collapsedMessages = collapseInternalToolCalls(messages, mockLogger); expect(collapsedMessages).toEqual(messages); }); @@ -108,7 +108,7 @@ describe('collapseMessages', () => { let collapsedMessages: Message[]; beforeEach(() => { const messages: Message[] = [...queryToolCalls, ...executeQueryToolCall]; - collapsedMessages = collapseMessages(messages, mockLogger); + collapsedMessages = collapseInternalToolCalls(messages, mockLogger); }); it('should collapse the "execute_query" tool call into the "query" tool response', () => { @@ -195,7 +195,7 @@ describe('collapseMessages', () => { }, ]; const messages: Message[] = [...queryToolCalls, ...visualizeQueryToolCall]; - collapsedMessages = collapseMessages(messages, mockLogger); + collapsedMessages = collapseInternalToolCalls(messages, mockLogger); }); it('should collapse "visualize_query" into the "query" response', () => { @@ -230,7 +230,7 @@ describe('collapseMessages', () => { }, }, ]; - collapsedMessages = collapseMessages(messages, mockLogger); + collapsedMessages = collapseInternalToolCalls(messages, mockLogger); }); it('should stop collapsing and preserve the unrelated tool call', () => { diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts index 1e8a449a25c06..f68915d886195 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts @@ -27,7 +27,7 @@ function safeJsonParse(jsonString: string | undefined, logger: Pick) { +export function collapseInternalToolCalls(messages: Message[], logger: Pick) { const collapsed: Message[] = []; for (let i = 0; i < messages.length; i++) { @@ -69,7 +69,7 @@ export function convertMessagesForInference( ): InferenceMessage[] { const inferenceMessages: InferenceMessage[] = []; - const collapsedMessages: Message[] = collapseMessages(messages, logger); + const collapsedMessages: Message[] = collapseInternalToolCalls(messages, logger); collapsedMessages.forEach((message, idx) => { if (message.message.role === MessageRole.Assistant) { From 45a6460fed1de3ade96d321e7231997c0b42aa35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 12:35:03 +0200 Subject: [PATCH 13/16] Revert "Remove temperature" This reverts commit 2ee1caaee78c4cc796e6534267eb809ec22d6ca4. --- .../server/service/client/index.test.ts | 10 +++++++++- .../server/service/client/index.ts | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts index 59d6bd3b46656..2718ba747e998 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts @@ -296,6 +296,8 @@ describe('Observability AI Assistant client', () => { system: 'You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.', functionCalling: 'auto', + maxRetries: 0, + temperature: 0.25, toolChoice: expect.objectContaining({ function: 'title_conversation', }), @@ -333,6 +335,8 @@ describe('Observability AI Assistant client', () => { { role: 'user', content: 'How many alerts do I have?' }, ]), functionCalling: 'auto', + maxRetries: 0, + temperature: 0.25, toolChoice: undefined, tools: undefined, metadata: { @@ -859,6 +863,8 @@ describe('Observability AI Assistant client', () => { { role: 'user', content: 'How many alerts do I have?' }, ]), functionCalling: 'auto', + maxRetries: 0, + temperature: 0.25, toolChoice: 'auto', tools: expect.any(Object), metadata: { @@ -1018,6 +1024,8 @@ describe('Observability AI Assistant client', () => { { role: 'user', content: 'How many alerts do I have?' }, ]), functionCalling: 'auto', + maxRetries: 0, + temperature: 0.25, toolChoice: 'auto', tools: expect.any(Object), metadata: { @@ -1285,7 +1293,7 @@ describe('Observability AI Assistant client', () => { })); stream = observableIntoStream( - await client.complete({ + client.complete({ connectorId: 'foo', messages: [user('How many alerts do I have?')], functionClient: functionClientMock, diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts index 41564926cf35b..470312d3672a7 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts @@ -509,6 +509,7 @@ export class ObservabilityAIAssistantClient { this.dependencies.inferenceClient .chatComplete({ ...options, + temperature: 0.25, stream: true, maxRetries: 0, messages: convertMessagesForInference(redactedMessages, this.dependencies.logger), @@ -536,6 +537,7 @@ export class ObservabilityAIAssistantClient { return this.dependencies.inferenceClient.chatComplete({ ...options, messages: convertMessagesForInference(messages, this.dependencies.logger), + temperature: 0.25, maxRetries: 0, stream: false, }) as TStream extends true ? never : Promise; From db5b388c3b700820f48fd2858fd7bee698d82a9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 12:57:32 +0200 Subject: [PATCH 14/16] Fix unit test --- .../server/service/client/index.test.ts | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts index 2718ba747e998..68aede962deb4 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts @@ -8,7 +8,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client'; import type { CoreSetup, ElasticsearchClient, IUiSettingsClient, Logger } from '@kbn/core/server'; import type { DeeplyMockedKeys } from '@kbn/utility-types-jest'; import { waitFor } from '@testing-library/react'; -import { last, merge, repeat } from 'lodash'; +import { isEmpty, last, merge, repeat, size } from 'lodash'; import { Subject, Observable } from 'rxjs'; import { EventEmitter, type Readable } from 'stream'; import { finished } from 'stream/promises'; @@ -140,6 +140,7 @@ describe('Observability AI Assistant client', () => { // uncomment this line for debugging // const consoleOrPassThrough = console.log.bind(console); + const consoleOrPassThrough = () => {}; loggerMock = { @@ -1288,9 +1289,11 @@ describe('Observability AI Assistant client', () => { ]); functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts'); - functionClientMock.executeFunction.mockImplementation(async () => ({ - content: 'Call this function again', - })); + functionClientMock.executeFunction.mockImplementation(async () => { + return { + content: 'Call this function again', + }; + }); stream = observableIntoStream( client.complete({ @@ -1311,7 +1314,7 @@ describe('Observability AI Assistant client', () => { const body = inferenceClientMock.chatComplete.mock.lastCall![0]; let nextLlmCallPromise: Promise; - if (Object.keys(body.tools ?? {}).length) { + if (!isEmpty(body.tools) && body.tools.exit_loop === undefined) { nextLlmCallPromise = waitForNextLlmCall(); await llmSimulator.chunk({ function_call: { name: 'get_top_alerts', arguments: '{}' } }); } else { @@ -1341,9 +1344,19 @@ describe('Observability AI Assistant client', () => { const firstBody = inferenceClientMock.chatComplete.mock.calls[0][0] as any; const body = inferenceClientMock.chatComplete.mock.lastCall![0] as any; - expect(Object.keys(firstBody.tools ?? {}).length).toEqual(1); + expect(size(firstBody.tools)).toEqual(1); - expect(body.tools).toEqual(undefined); + expect(body.tools).toEqual({ + exit_loop: { + description: + "You've run out of tool calls. Call this tool, and explain to the user you've run out of budget.", + schema: { + properties: { response: { description: 'Your textual response', type: 'string' } }, + required: ['response'], + type: 'object', + }, + }, + }); }); }); From f39bee9a7482d501ee2bc9b443a48babbcf0764c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 14:53:22 +0200 Subject: [PATCH 15/16] Update unit test --- .../adapters/bedrock/bedrock_claude_adapter.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts index 99251e6e53d8d..494fcba2fb3d9 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts @@ -390,7 +390,7 @@ Human:`, const { toolChoice, tools, system } = getCallParams(); expect(toolChoice).toBeUndefined(); - expect(tools).toEqual([]); + expect(tools).toEqual(undefined); // Claude requires tools to be undefined when no tools are available expect(system).toEqual([{ text: addNoToolUsageDirective('some system instruction') }]); }); From d43bb3646319e8e602c7064e478acee4dd8fec94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Louv-Jansen?= Date: Fri, 4 Jul 2025 17:23:32 +0200 Subject: [PATCH 16/16] Fix tests --- .../convert_messages_for_inference.test.ts | 248 +++++++++++++----- .../common/convert_messages_for_inference.ts | 2 +- 2 files changed, 181 insertions(+), 69 deletions(-) diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts index 311528be57793..5685e5bdf6a9f 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { InferenceMessage } from '@elastic/elasticsearch/lib/api/types'; import { collapseInternalToolCalls } from './convert_messages_for_inference'; import { Message, MessageRole } from './types'; @@ -15,24 +16,70 @@ const mockLogger = { trace: jest.fn(), }; -const queryToolCalls: Message[] = [ +const userMessage: (msg: string) => Message = (msg: string) => ({ + '@timestamp': '2025-07-02T10:00:00Z', + message: { + role: MessageRole.User, + content: msg, + }, +}); + +const assistantMessage: (msg: string) => Message = (msg: string) => ({ + '@timestamp': '2025-07-02T10:01:00Z', + message: { + content: msg, + role: MessageRole.Assistant, + }, +}); + +const getDatasetInfoTool: Message[] = [ { - '@timestamp': '2025-07-02T10:00:00Z', + '@timestamp': '2025-07-02T10:01:00Z', message: { + content: + "I'll help you visualize logs from your system. First, let me check what log indices are available:", + function_call: { + name: 'get_dataset_info', + arguments: '{"index": "logs-*"}', + trigger: MessageRole.Assistant, + }, + role: MessageRole.Assistant, + }, + }, + { + '@timestamp': '2025-07-02T10:01:00Z', + message: { + content: JSON.stringify({ + indices: ['remote_cluster:logs-cloud_security_posture.scores-default'], + fields: ['@timestamp:date', 'log.level:keyword'], + stats: { + analyzed: 386, + total: 386, + }, + }), + name: 'get_dataset_info', role: MessageRole.User, + }, + }, +]; + +const queryTool: Message[] = [ + { + '@timestamp': '2025-07-04T14:32:53.974Z', + message: { + content: 'Now that I can see the available log indices, let me visualize some logs for you:', function_call: { name: 'query', arguments: '', trigger: MessageRole.Assistant, }, - content: - 'I can see that we have logs indices with a `client.ip` field. Let me create a query to count the unique IP addresses.', + role: MessageRole.Assistant, }, }, { - '@timestamp': '2025-07-02T10:00:01Z', + '@timestamp': '2025-07-04T14:32:57.331Z', message: { - role: MessageRole.User, + content: '{}', data: JSON.stringify({ keywords: ['STATS', 'COUNT_DISTINCT'], requestedDocumentation: { @@ -41,22 +88,22 @@ const queryToolCalls: Message[] = [ }, }), name: 'query', - content: '{}', + role: MessageRole.User, }, }, ]; -const executeQueryToolCall: Message[] = [ +const executeQueryTool: Message[] = [ { '@timestamp': '2025-07-02T10:01:00Z', message: { + content: undefined, role: MessageRole.Assistant, function_call: { name: 'execute_query', arguments: '{"query":"FROM logs"}', trigger: MessageRole.Assistant, }, - content: undefined, }, }, { @@ -72,6 +119,40 @@ const executeQueryToolCall: Message[] = [ }, ]; +const visualizeQueryTool: Message[] = [ + { + '@timestamp': '2025-07-04T14:33:03.937Z', + message: { + content: + "Now I'll create a visualization of your logs. Let me query the available logs and create a meaningful visualization:", + role: MessageRole.Assistant, + function_call: { + name: 'visualize_query', + arguments: '{"query":"FROM remote_cluster:logs-* | LIMIT 10","intention":"visualizeBar"}', + trigger: MessageRole.Assistant, + }, + }, + }, + { + '@timestamp': '2025-07-04T14:33:33.978Z', + message: { + content: JSON.stringify({ + errorMessages: ['Request timed out'], + message: + 'Only following query is visualized: ```esql\nFROM remote_cluster:logs-* | LIMIT 10\n```', + }), + data: JSON.stringify({ + columns: [], + rows: [], + correctedQuery: + 'FROM remote_cluster:logs-*\n| WHERE @timestamp >= NOW() - 24 hours\n| STATS count = COUNT(*) BY data_stream.dataset, log.level\n| SORT count DESC\n| LIMIT 10', + }), + name: 'visualize_query', + role: MessageRole.User, + }, + }, +]; + describe('collapseInternalToolCalls', () => { beforeEach(() => { jest.clearAllMocks(); @@ -98,53 +179,86 @@ describe('collapseInternalToolCalls', () => { '@timestamp': '2025-07-02T10:00:00Z', message: { role: MessageRole.User, content: 'hello' }, }, - ...queryToolCalls, + ...queryTool, ]; const collapsedMessages = collapseInternalToolCalls(messages, mockLogger); expect(collapsedMessages).toEqual(messages); }); - describe('when collapsing an "execute_query" tool call', () => { + describe('when a conversation contains a "query" followed by "execute_query" tool call', () => { let collapsedMessages: Message[]; + let messages: Message[]; beforeEach(() => { - const messages: Message[] = [...queryToolCalls, ...executeQueryToolCall]; + messages = [ + userMessage('Please analyze my logs'), + ...getDatasetInfoTool, + ...queryTool, + ...executeQueryTool, + assistantMessage('Here is the result'), + userMessage('What about the unique IPs?'), + ]; collapsedMessages = collapseInternalToolCalls(messages, mockLogger); }); - it('should collapse the "execute_query" tool call into the "query" tool response', () => { - expect(collapsedMessages).toEqual([ - { - '@timestamp': expect.any(String), - message: { - content: expect.any(String), - function_call: { arguments: '', name: 'query', trigger: 'assistant' }, - role: 'user', - }, - }, - { - '@timestamp': expect.any(String), - message: { - content: expect.stringContaining('execute_query'), - data: expect.any(String), - name: 'query', - role: 'user', - }, - }, + it('should have the right messages after collapsing', () => { + const formatMessages = (msg: Message) => ({ + role: msg.message.role, + toolName: msg.message.function_call?.name, + }); + + // before collapsing + expect(messages.map(formatMessages)).toEqual([ + { role: 'user' }, + { role: 'assistant', toolName: 'get_dataset_info' }, + { role: 'user' }, + { role: 'assistant', toolName: 'query' }, + { role: 'user' }, + { role: 'assistant', toolName: 'execute_query' }, + { role: 'user' }, + { role: 'assistant' }, + { role: 'user' }, ]); + + // after collapsing + expect(collapsedMessages.map(formatMessages)).toEqual([ + { role: 'user' }, + { role: 'assistant', toolName: 'get_dataset_info' }, + { role: 'user' }, + { role: 'assistant', toolName: 'query' }, + { role: 'user' }, + { role: 'assistant' }, + { role: 'user' }, + ]); + }); + + it('should retain the messages up until the query response', () => { + expect(messages.slice(0, 4)).toEqual(collapsedMessages.slice(0, 4)); + }); + + it('should retain the messages after the "execute_query" response', () => { + expect(messages.slice(-2)).toEqual(collapsedMessages.slice(-2)); }); - it('should retain the query tool request', () => { - expect(collapsedMessages[0]).toEqual(queryToolCalls[0]); + it('should remove the "execute_query" messages', () => { + expect(collapsedMessages).not.toContain(executeQueryTool[0]); + expect(collapsedMessages).not.toContain(executeQueryTool[1]); }); - it('should contain "execute_query" in "steps" property', () => { - const content = JSON.parse(collapsedMessages[1].message.content!); + it('should retain the "query" tool request', () => { + expect(collapsedMessages).toContain(queryTool[0]); + }); + + it('should collapse the "execute_query" calls into the "query" tool response', () => { + const queryToolResponse = collapsedMessages.find( + (msg) => msg.message.role === MessageRole.User && msg.message.name === 'query' + )!; + + const content = JSON.parse(queryToolResponse.message.content!); expect(content.steps).toHaveLength(2); expect(content.steps[0].role).toBe('assistant'); expect(content.steps[1].role).toBe('tool'); expect(content.steps[1].name).toBe('execute_query'); - expect(content.steps).toEqual([ { content: null, @@ -171,43 +285,41 @@ describe('collapseInternalToolCalls', () => { describe('when a query message is followed by "visualize_query" tool pair', () => { let collapsedMessages: Message[]; + let messages: Message[]; beforeEach(() => { - const visualizeQueryToolCall: Message[] = [ - { - '@timestamp': '2025-07-02T10:01:00Z', - message: { - role: MessageRole.Assistant, - function_call: { - name: 'visualize_query', - arguments: '{"query":"FROM logs | STATS count() BY response.keyword"}', - trigger: MessageRole.Assistant, - }, - content: undefined, - }, - }, - { - '@timestamp': '2025-07-02T10:01:01Z', - message: { - role: MessageRole.User, - name: 'visualize_query', - content: JSON.stringify({ viz: 'some vega spec' }), - }, - }, + messages = [ + userMessage('Please visualize my logs'), + ...getDatasetInfoTool, + ...queryTool, + ...visualizeQueryTool, + assistantMessage('Here is the result'), + userMessage('What about the unique IPs?'), ]; - const messages: Message[] = [...queryToolCalls, ...visualizeQueryToolCall]; + collapsedMessages = collapseInternalToolCalls(messages, mockLogger); }); it('should collapse "visualize_query" into the "query" response', () => { - expect(collapsedMessages[1].message.content).toContain('visualize_query'); - }); + const queryToolResponse = collapsedMessages.find( + (msg) => msg.message.role === MessageRole.User && msg.message.name === 'query' + )!; - it('should serialize the collapsed steps correctly', () => { - const content = JSON.parse(collapsedMessages[1].message.content!); - expect(content.steps).toHaveLength(2); - expect(content.steps[0].role).toBe('assistant'); - expect(content.steps[1].role).toBe('tool'); - expect(content.steps[1].name).toBe('visualize_query'); + const steps = JSON.parse(queryToolResponse.message.content!).steps as [ + InferenceMessage, + InferenceMessage + ]; + + const [toolCallRequest, toolCallResponse] = steps; + + expect(steps).toHaveLength(2); + + // @ts-expect-error + expect(toolCallRequest.toolCalls[0].function.name).toContain('visualize_query'); + expect(toolCallRequest.role).toContain('assistant'); + + // @ts-expect-error + expect(toolCallResponse.name).toContain('visualize_query'); + expect(toolCallResponse.role).toContain('tool'); }); }); @@ -215,8 +327,8 @@ describe('collapseInternalToolCalls', () => { let collapsedMessages: Message[]; beforeEach(() => { const messages: Message[] = [ - ...queryToolCalls, - ...executeQueryToolCall, + ...queryTool, + ...executeQueryTool, { '@timestamp': '2025-07-02T10:02:00Z', message: { diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts index f68915d886195..6ce7343e2a6b8 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts @@ -36,7 +36,7 @@ export function collapseInternalToolCalls(messages: Message[], logger: Pick { const name = msg.message.name || msg.message.function_call?.name; - return !name || ['query', 'visualize_query', 'execute_query'].includes(name); + return name && ['query', 'visualize_query', 'execute_query'].includes(name); }); if (messagesToCollapse.length) {