diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts
index 31480588b12bf..394e2655bb850 100644
--- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts
+++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts
@@ -409,7 +409,13 @@ describe('bedrockClaudeAdapter', () => {
 
       const { toolChoice, tools, system } = getCallParams();
       expect(toolChoice).toBeUndefined();
-      expect(tools).toEqual([]);
+      expect(tools).toEqual([
+        {
+          description: 'myFunction',
+          input_schema: { properties: {}, type: 'object' },
+          name: 'myFunction',
+        },
+      ]);
       expect(system).toEqual(addNoToolUsageDirective('some system instruction'));
     });
 
diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts
index 9e81152efb87b..22a1df9910a5e 100644
--- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts
+++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts
@@ -38,10 +38,12 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
   }) => {
     const noToolUsage = toolChoice === ToolChoiceType.none;
 
+    const bedRockTools = toolsToBedrock(tools, messages);
+
     const subActionParams = {
       system: noToolUsage ? addNoToolUsageDirective(system) : system,
       messages: messagesToBedrock(messages),
-      tools: noToolUsage ? [] : toolsToBedrock(tools, messages),
+      tools: bedRockTools?.length ? bedRockTools : undefined,
       toolChoice: toolChoiceToBedrock(toolChoice),
       temperature,
       model: modelName,
diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts
index 3aa5f7815e019..b907aac375a4a 100644
--- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts
+++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/chunks_into_message.ts
@@ -43,14 +43,13 @@ export function chunksIntoMessage({
 
         logger.debug(() => `Received completed message: ${JSON.stringify(concatenatedChunk)}`);
 
-        const validatedToolCalls = validateToolCalls({
-          ...toolOptions,
-          toolCalls: concatenatedChunk.tool_calls,
-        });
+        const { content, tool_calls: toolCalls } = concatenatedChunk;
+
+        const validatedToolCalls = validateToolCalls({ ...toolOptions, toolCalls });
 
         return {
           type: ChatCompletionEventType.ChatCompletionMessage,
-          content: concatenatedChunk.content,
+          content,
           toolCalls: validatedToolCalls,
         };
       })
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts
new file mode 100644
index 0000000000000..7a2d0c9fe8846
--- /dev/null
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.test.ts
@@ -0,0 +1,356 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { Message as InferenceMessage } from '@kbn/inference-common';
+import { collapseInternalToolCalls } from './convert_messages_for_inference';
+import { Message, MessageRole } from './types';
+
+const userMessage: (msg: string) => Message = (msg: string) => ({
+  '@timestamp': '2025-07-02T10:00:00Z',
+  message: {
+    role: MessageRole.User,
+    content: msg,
+  },
+});
+
+const assistantMessage: (msg: string) => Message = (msg: string) => ({
+  '@timestamp': '2025-07-02T10:01:00Z',
+  message: {
+    content: msg,
+    role: MessageRole.Assistant,
+  },
+});
+
+const getDatasetInfoTool: Message[] = [
+  {
+    '@timestamp': '2025-07-02T10:01:00Z',
+    message: {
+      content:
+        "I'll help you visualize logs from your system. First, let me check what log indices are available:",
+      function_call: {
+        name: 'get_dataset_info',
+        arguments: '{"index": "logs-*"}',
+        trigger: MessageRole.Assistant,
+      },
+      role: MessageRole.Assistant,
+    },
+  },
+  {
+    '@timestamp': '2025-07-02T10:01:00Z',
+    message: {
+      content: JSON.stringify({
+        indices: ['remote_cluster:logs-cloud_security_posture.scores-default'],
+        fields: ['@timestamp:date', 'log.level:keyword'],
+        stats: {
+          analyzed: 386,
+          total: 386,
+        },
+      }),
+      name: 'get_dataset_info',
+      role: MessageRole.User,
+    },
+  },
+];
+
+const queryTool: Message[] = [
+  {
+    '@timestamp': '2025-07-04T14:32:53.974Z',
+    message: {
+      content: 'Now that I can see the available log indices, let me visualize some logs for you:',
+      function_call: {
+        name: 'query',
+        arguments: '',
+        trigger: MessageRole.Assistant,
+      },
+      role: MessageRole.Assistant,
+    },
+  },
+  {
+    '@timestamp': '2025-07-04T14:32:57.331Z',
+    message: {
+      content: '{}',
+      data: JSON.stringify({
+        keywords: ['STATS', 'COUNT_DISTINCT'],
+        requestedDocumentation: {
+          STATS: 'Aggregates data using statistical functions.',
+          COUNT_DISTINCT: 'Counts distinct values in a field.',
+        },
+      }),
+      name: 'query',
+      role: MessageRole.User,
+    },
+  },
+];
+
+const executeQueryTool: Message[] = [
+  {
+    '@timestamp': '2025-07-02T10:01:00Z',
+    message: {
+      content: undefined,
+      role: MessageRole.Assistant,
+      function_call: {
+        name: 'execute_query',
+        arguments: '{"query":"FROM logs"}',
+        trigger: MessageRole.Assistant,
+      },
+    },
+  },
+  {
+    '@timestamp': '2025-07-02T10:01:01Z',
+    message: {
+      role: MessageRole.User,
+      name: 'execute_query',
+      content: JSON.stringify({
+        columns: [{ id: 'unique_ips', name: 'unique_ips', meta: { type: 'number' } }],
+        rows: [[324567]],
+      }),
+    },
+  },
+];
+
+const visualizeQueryTool: Message[] = [
+  {
+    '@timestamp': '2025-07-04T14:33:03.937Z',
+    message: {
+      content:
+        "Now I'll create a visualization of your logs. Let me query the available logs and create a meaningful visualization:",
+      role: MessageRole.Assistant,
+      function_call: {
+        name: 'visualize_query',
+        arguments: '{"query":"FROM remote_cluster:logs-* | LIMIT 10","intention":"visualizeBar"}',
+        trigger: MessageRole.Assistant,
+      },
+    },
+  },
+  {
+    '@timestamp': '2025-07-04T14:33:33.978Z',
+    message: {
+      content: JSON.stringify({
+        errorMessages: ['Request timed out'],
+        message:
+          'Only following query is visualized: ```esql\nFROM remote_cluster:logs-* | LIMIT 10\n```',
+      }),
+      data: JSON.stringify({
+        columns: [],
+        rows: [],
+        correctedQuery:
+          'FROM remote_cluster:logs-*\n| WHERE @timestamp >= NOW() - 24 hours\n| STATS count = COUNT(*) BY data_stream.dataset, log.level\n| SORT count DESC\n| LIMIT 10',
+      }),
+      name: 'visualize_query',
+      role: MessageRole.User,
+    },
+  },
+];
+
+describe('collapseInternalToolCalls', () => {
+  beforeEach(() => {
+    jest.clearAllMocks();
+  });
+
+  it('should not collapse messages if there are no query messages', () => {
+    const messages: Message[] = [
+      {
+        '@timestamp': '2025-07-02T10:00:00Z',
+        message: { role: MessageRole.User, content: 'hello' },
+      },
+      {
+        '@timestamp': '2025-07-02T10:01:00Z',
+        message: { role: MessageRole.Assistant, content: 'hi there' },
+      },
+    ];
+    const collapsedMessages = collapseInternalToolCalls(messages);
+    expect(collapsedMessages).toEqual(messages);
+  });
+
+  it('should not collapse a query message if there are no messages after it', () => {
+    const messages: Message[] = [
+      {
+        '@timestamp': '2025-07-02T10:00:00Z',
+        message: { role: MessageRole.User, content: 'hello' },
+      },
+      ...queryTool,
+    ];
+    const collapsedMessages = collapseInternalToolCalls(messages);
+    expect(collapsedMessages).toEqual(messages);
+  });
+
+  describe('when a conversation contains a "query" followed by "execute_query" tool call', () => {
+    let collapsedMessages: Message[];
+    let messages: Message[];
+    beforeEach(() => {
+      messages = [
+        userMessage('Please analyze my logs'),
+        ...getDatasetInfoTool,
+        ...queryTool,
+        ...executeQueryTool,
+        assistantMessage('Here is the result'),
+        userMessage('What about the unique IPs?'),
+      ];
+      collapsedMessages = collapseInternalToolCalls(messages);
+    });
+
+    it('should have the right messages after collapsing', () => {
+      const formatMessages = (msg: Message) => ({
+        role: msg.message.role,
+        toolName: msg.message.function_call?.name,
+      });
+
+      // before collapsing
+      expect(messages.map(formatMessages)).toEqual([
+        { role: 'user' },
+        { role: 'assistant', toolName: 'get_dataset_info' },
+        { role: 'user' },
+        { role: 'assistant', toolName: 'query' },
+        { role: 'user' },
+        { role: 'assistant', toolName: 'execute_query' },
+        { role: 'user' },
+        { role: 'assistant' },
+        { role: 'user' },
+      ]);
+
+      // after collapsing
+      expect(collapsedMessages.map(formatMessages)).toEqual([
+        { role: 'user' },
+        { role: 'assistant', toolName: 'get_dataset_info' },
+        { role: 'user' },
+        { role: 'assistant', toolName: 'query' },
+        { role: 'user' },
+        { role: 'assistant' },
+        { role: 'user' },
+      ]);
+    });
+
+    it('should retain the messages up until the query response', () => {
+      expect(messages.slice(0, 4)).toEqual(collapsedMessages.slice(0, 4));
+    });
+
+    it('should retain the messages after the "execute_query" response', () => {
+      expect(messages.slice(-2)).toEqual(collapsedMessages.slice(-2));
+    });
+
+    it('should remove the "execute_query" messages', () => {
+      expect(collapsedMessages).not.toContain(executeQueryTool[0]);
+      expect(collapsedMessages).not.toContain(executeQueryTool[1]);
+    });
+
+    it('should retain the "query" tool request', () => {
+      expect(collapsedMessages).toContain(queryTool[0]);
+    });
+
+    it('should collapse the "execute_query" calls into the "query" tool response', () => {
+      const queryToolResponse = collapsedMessages.find(
+        (msg) => msg.message.role === MessageRole.User && msg.message.name === 'query'
+      )!;
+
+      const content = JSON.parse(queryToolResponse.message.content!);
+
+      expect(content.steps).toHaveLength(2);
+      expect(content.steps[0].role).toBe('assistant');
+      expect(content.steps[1].role).toBe('tool');
+      expect(content.steps[1].name).toBe('execute_query');
+      expect(content.steps).toEqual([
+        {
+          content: null,
+          role: 'assistant',
+          toolCalls: [
+            {
+              function: { arguments: { query: 'FROM logs' }, name: 'execute_query' },
+              toolCallId: expect.any(String),
+            },
+          ],
+        },
+        {
+          name: 'execute_query',
+          response: {
+            columns: [{ id: 'unique_ips', meta: { type: 'number' }, name: 'unique_ips' }],
+            rows: [[324567]],
+          },
+          role: 'tool',
+          toolCallId: expect.any(String),
+        },
+      ]);
+    });
+  });
+
+  describe('when a query message is followed by "visualize_query" tool pair', () => {
+    let collapsedMessages: Message[];
+    let messages: Message[];
+    beforeEach(() => {
+      messages = [
+        userMessage('Please visualize my logs'),
+        ...getDatasetInfoTool,
+        ...queryTool,
+        ...visualizeQueryTool,
+        assistantMessage('Here is the result'),
+        userMessage('What about the unique IPs?'),
+      ];
+
+      collapsedMessages = collapseInternalToolCalls(messages);
+    });
+
+    it('should collapse "visualize_query" into the "query" response', () => {
+      const queryToolResponse = collapsedMessages.find(
+        (msg) => msg.message.role === MessageRole.User && msg.message.name === 'query'
+      )!;
+
+      const steps = JSON.parse(queryToolResponse.message.content!).steps as [
+        InferenceMessage,
+        InferenceMessage
+      ];
+
+      const [toolCallRequest, toolCallResponse] = steps;
+
+      expect(steps).toHaveLength(2);
+
+      // @ts-expect-error
+      expect(toolCallRequest.toolCalls[0].function.name).toContain('visualize_query');
+      expect(toolCallRequest.role).toContain('assistant');
+
+      // @ts-expect-error
+      expect(toolCallResponse.name).toContain('visualize_query');
+      expect(toolCallResponse.role).toContain('tool');
+    });
+  });
+
+  describe('when an unrelated tool call is present', () => {
+    let collapsedMessages: Message[];
+    beforeEach(() => {
+      const messages: Message[] = [
+        ...queryTool,
+        ...executeQueryTool,
+        {
+          '@timestamp': '2025-07-02T10:02:00Z',
+          message: {
+            role: MessageRole.Assistant,
+            function_call: {
+              name: 'some_other_function',
+              arguments: JSON.stringify({ user: 'george' }),
+              trigger: MessageRole.Assistant,
+            },
+            content: undefined,
+          },
+        },
+      ];
+      collapsedMessages = collapseInternalToolCalls(messages);
+    });
+
+    it('should stop collapsing and preserve the unrelated tool call', () => {
+      expect(collapsedMessages).toHaveLength(3);
+    });
+
+    it('should add "execute_query" to the "query" response', () => {
+      const queryToolResponse = collapsedMessages[1];
+      expect(queryToolResponse.message.content).toContain('execute_query');
+    });
+
+    it('should retain the unrelated tool call as the last message', () => {
+      expect(collapsedMessages[2].message.function_call?.name).toEqual('some_other_function');
+    });
+  });
+});
+
+describe('convertMessagesForInference', () => {});
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts
index 229183ed142a7..fe7ad87a46720 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts
@@ -11,12 +11,51 @@ import {
   MessageRole as InferenceMessageRole,
 } from '@kbn/inference-common';
 import { generateFakeToolCallId } from '@kbn/inference-plugin/common';
+import { takeWhile } from 'lodash';
 import { Message, MessageRole } from '.';
 
+export function collapseInternalToolCalls(messages: Message[]) {
+  const collapsed: Message[] = [];
+
+  for (let i = 0; i < messages.length; i++) {
+    const message = messages[i];
+
+    if (message.message.role === MessageRole.User && message.message.name === 'query') {
+      const messagesToCollapse = takeWhile(messages.slice(i + 1), (msg) => {
+        const name = msg.message.name || msg.message.function_call?.name;
+        return name && ['query', 'visualize_query', 'execute_query'].includes(name);
+      });
+
+      if (messagesToCollapse.length) {
+        const content = JSON.parse(message.message.content!);
+        collapsed.push({
+          ...message,
+          message: {
+            ...message.message,
+            content: JSON.stringify({
+              ...content,
+              steps: convertMessagesForInference(messagesToCollapse),
+            }),
+          },
+        });
+
+        i += messagesToCollapse.length;
+        continue;
+      }
+    }
+
+    collapsed.push(message);
+  }
+
+  return collapsed;
+}
+
 export function convertMessagesForInference(messages: Message[]): InferenceMessage[] {
   const inferenceMessages: InferenceMessage[] = [];
 
-  messages.forEach((message) => {
+  const collapsedMessages: Message[] = collapseInternalToolCalls(messages);
+
+  collapsedMessages.forEach((message, idx) => {
     if (message.message.role === MessageRole.Assistant) {
       inferenceMessages.push({
         role: InferenceMessageRole.Assistant,
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts
index 1d5c7a3ccf12f..2d9733f490d2e 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.test.ts
@@ -8,7 +8,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
 import type { CoreSetup, ElasticsearchClient, IUiSettingsClient, Logger } from '@kbn/core/server';
 import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
 import { waitFor } from '@testing-library/react';
-import { last, merge, repeat } from 'lodash';
+import { isEmpty, last, merge, repeat, size } from 'lodash';
 import { Subject, Observable } from 'rxjs';
 import { EventEmitter, type Readable } from 'stream';
 import { finished } from 'stream/promises';
@@ -139,6 +139,7 @@ describe('Observability AI Assistant client', () => {
 
   // uncomment this line for debugging
   // const consoleOrPassThrough = console.log.bind(console);
+
   const consoleOrPassThrough = () => {};
 
   loggerMock = {
@@ -282,6 +283,8 @@ describe('Observability AI Assistant client', () => {
           system:
            'You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.',
           functionCalling: 'auto',
+          maxRetries: 0,
+          temperature: 0.25,
           toolChoice: expect.objectContaining({
             function: 'title_conversation',
           }),
@@ -319,6 +322,8 @@ describe('Observability AI Assistant client', () => {
             { role: 'user', content: 'How many alerts do I have?' },
           ]),
           functionCalling: 'auto',
+          maxRetries: 0,
+          temperature: 0.25,
           toolChoice: undefined,
           tools: undefined,
           metadata: {
@@ -841,6 +846,8 @@ describe('Observability AI Assistant client', () => {
             { role: 'user', content: 'How many alerts do I have?' },
           ]),
           functionCalling: 'auto',
+          maxRetries: 0,
+          temperature: 0.25,
           toolChoice: 'auto',
           tools: expect.any(Object),
           metadata: {
@@ -995,6 +1002,8 @@ describe('Observability AI Assistant client', () => {
             { role: 'user', content: 'How many alerts do I have?' },
           ]),
           functionCalling: 'auto',
+          maxRetries: 0,
+          temperature: 0.25,
           toolChoice: 'auto',
           tools: expect.any(Object),
           metadata: {
@@ -1257,12 +1266,14 @@ describe('Observability AI Assistant client', () => {
       ]);
 
       functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
-      functionClientMock.executeFunction.mockImplementation(async () => ({
-        content: 'Call this function again',
-      }));
+      functionClientMock.executeFunction.mockImplementation(async () => {
+        return {
+          content: 'Call this function again',
+        };
+      });
 
       stream = observableIntoStream(
-        await client.complete({
+        client.complete({
           connectorId: 'foo',
           messages: [user('How many alerts do I have?')],
           functionClient: functionClientMock,
@@ -1280,7 +1291,7 @@ describe('Observability AI Assistant client', () => {
         const body = inferenceClientMock.chatComplete.mock.lastCall![0];
 
         let nextLlmCallPromise: Promise;
-        if (Object.keys(body.tools ?? {}).length) {
+        if (!isEmpty(body.tools) && body.tools.exit_loop === undefined) {
           nextLlmCallPromise = waitForNextLlmCall();
           await llmSimulator.chunk({ function_call: { name: 'get_top_alerts', arguments: '{}' } });
         } else {
@@ -1310,9 +1321,19 @@ describe('Observability AI Assistant client', () => {
 
       const firstBody = inferenceClientMock.chatComplete.mock.calls[0][0] as any;
       const body = inferenceClientMock.chatComplete.mock.lastCall![0] as any;
 
-      expect(Object.keys(firstBody.tools ?? {}).length).toEqual(1);
+      expect(size(firstBody.tools)).toEqual(1);
 
-      expect(body.tools).toEqual(undefined);
+      expect(body.tools).toEqual({
+        exit_loop: {
+          description:
+            "You've run out of tool calls. Call this tool, and explain to the user you've run out of budget.",
+          schema: {
+            properties: { response: { description: 'Your textual response', type: 'string' } },
+            required: ['response'],
+            type: 'object',
+          },
+        },
+      });
     });
   });
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts
index adec401ce243a..b8a36f886c12e 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/index.ts
@@ -479,7 +479,9 @@ export class ObservabilityAIAssistantClient {
       return defer(() =>
         this.dependencies.inferenceClient.chatComplete({
           ...options,
+          temperature: 0.25,
           stream: true,
+          maxRetries: 0,
         })
       ).pipe(
         convertInferenceEventsToStreamingEvents(),
@@ -501,6 +503,8 @@ export class ObservabilityAIAssistantClient {
     return this.dependencies.inferenceClient.chatComplete({
       ...options,
       stream: false,
+      temperature: 0.25,
+      maxRetries: 0,
     }) as TStream extends true ? never : Promise;
   }
 }
diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
index a1aeee5aafa82..5c3c5eac405e2 100644
--- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
+++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/service/client/operators/continue_conversation.ts
@@ -7,30 +7,38 @@
 
 import { Logger } from '@kbn/logging';
 import { decode, encode } from 'gpt-tokenizer';
-import { last, pick, take } from 'lodash';
+import { last, omit, pick, take } from 'lodash';
 import {
+  EMPTY,
+  Observable,
+  OperatorFunction,
   catchError,
   concat,
-  EMPTY,
   from,
   isObservable,
-  Observable,
+  map,
   of,
-  OperatorFunction,
   shareReplay,
   switchMap,
   throwError,
 } from 'rxjs';
-import { CONTEXT_FUNCTION_NAME } from '../../../functions/context';
-import { createFunctionNotFoundError, Message, MessageRole } from '../../../../common';
 import {
-  createFunctionLimitExceededError,
+  CompatibleJSONSchema,
+  Message,
+  MessageAddEvent,
+  MessageRole,
+  StreamingChatResponseEventType,
+  createFunctionNotFoundError,
+} from '../../../../common';
+import {
   MessageOrChatEvent,
+  createFunctionLimitExceededError,
 } from '../../../../common/conversation_complete';
 import { FunctionVisibility } from '../../../../common/functions/types';
 import { Instruction } from '../../../../common/types';
 import { createFunctionResponseMessage } from '../../../../common/utils/create_function_response_message';
 import { emitWithConcatenatedMessage } from '../../../../common/utils/emit_with_concatenated_message';
+import { CONTEXT_FUNCTION_NAME } from '../../../functions/context';
 import type { ChatFunctionClient } from '../../chat_function_client';
 import type { AutoAbortedChatFunction } from '../../types';
 import { createServerSideFunctionResponseError } from '../../util/create_server_side_function_response_error';
@@ -40,6 +48,8 @@ import { extractMessages } from './extract_messages';
 
 const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
 
+const EXIT_LOOP_FUNCTION_NAME = 'exit_loop';
+
 function executeFunctionAndCatchError({
   name,
   args,
@@ -128,17 +138,42 @@ function executeFunctionAndCatchError({
     });
 }
 
-function getFunctionDefinitions({
+function getFunctionOptions({
   functionClient,
-  functionLimitExceeded,
   disableFunctions,
+  functionLimitExceeded,
 }: {
   functionClient: ChatFunctionClient;
-  functionLimitExceeded: boolean;
   disableFunctions: boolean;
-}) {
-  if (functionLimitExceeded || disableFunctions === true) {
-    return [];
+  functionLimitExceeded: boolean;
+}): {
+  functions?: Array<{ name: string; description: string; parameters?: CompatibleJSONSchema }>;
+  functionCall?: string;
+} {
+  if (disableFunctions === true) {
+    return {};
+  }
+
+  if (functionLimitExceeded) {
+    return {
+      functionCall: EXIT_LOOP_FUNCTION_NAME,
+      functions: [
+        {
+          name: EXIT_LOOP_FUNCTION_NAME,
+          description: `You've run out of tool calls. Call this tool, and explain to the user you've run out of budget.`,
+          parameters: {
+            type: 'object',
+            properties: {
+              response: {
+                type: 'string',
+                description: 'Your textual response',
+              },
+            },
+            required: ['response'],
+          },
+        },
+      ],
+    };
   }
 
   const systemFunctions = functionClient
@@ -156,7 +191,7 @@
     .concat(actions)
     .map((definition) => pick(definition, 'name', 'description', 'parameters'));
 
-  return allDefinitions;
+  return { functions: allDefinitions };
 }
 
 export function continueConversation({
@@ -190,10 +225,10 @@ export function continueConversation({
 
   const functionLimitExceeded = functionCallsLeft <= 0;
 
-  const definitions = getFunctionDefinitions({
-    functionLimitExceeded,
+  const functionOptions = getFunctionOptions({
     functionClient,
     disableFunctions,
+    functionLimitExceeded,
   });
 
   const lastMessage = last(initialMessages)?.message;
@@ -210,10 +245,10 @@ export function continueConversation({
 
     return chat(operationName, {
       messages: initialMessages,
-      functions: definitions,
       tracer,
       connectorId,
       stream: true,
+      ...functionOptions,
    }).pipe(emitWithConcatenatedMessage(), catchFunctionNotFoundError(functionLimitExceeded));
   }
 
@@ -295,7 +330,32 @@ export function continueConversation({
 
 function handleEvents(): OperatorFunction {
   return (events$) => {
-    const shared$ = events$.pipe(shareReplay());
+    const shared$ = events$.pipe(
+      shareReplay(),
+      map((event) => {
+        if (event.type === StreamingChatResponseEventType.MessageAdd) {
+          const message = event.message;
+
+          if (message.message.function_call?.name === EXIT_LOOP_FUNCTION_NAME) {
+            const args = JSON.parse(message.message.function_call.arguments ?? '{}') as {
+              response: string;
+            };
+
+            return {
+              ...event,
+              message: {
+                ...message,
+                message: {
+                  ...omit(message.message, 'function_call', 'content'),
+                  content: args.response ?? `The model returned an empty response`,
+                },
+              },
+            } satisfies MessageAddEvent;
+          }
+        }
+        return event;
+      })
+    );
 
     return concat(
       shared$,
diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts
index 7d4838f0b0362..b91fcc132593c 100644
--- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts
+++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts
@@ -48,10 +48,7 @@ export function registerQueryFunction({
   If the user asks for a query, and one of the dataset info functions was called and returned no results, you should still call the query function to generate an example query.
   Even if the "${QUERY_FUNCTION_NAME}" function was used before that, follow it up with the "${QUERY_FUNCTION_NAME}" function.
   If a query fails, do not attempt to correct it yourself. Again you should call the "${QUERY_FUNCTION_NAME}" function,
-  even if it has been called before.
-
-  When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.
-  If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.`;
+  even if it has been called before.`;
   });
 
   functions.registerFunction(
@@ -126,13 +123,15 @@ export function registerQueryFunction({
 
       const actions = functions.getActions();
 
+      const inferenceMessages = convertMessagesForInference(
+        // remove system message and query function request
+        messages.filter((message) => message.message.role !== MessageRole.System).slice(0, -1)
+      );
+
       const events$ = naturalLanguageToEsql({
         client: pluginsStart.inference.getClient({ request: resources.request }),
         connectorId,
-        messages: convertMessagesForInference(
-          // remove system message and query function request
-          messages.filter((message) => message.message.role !== MessageRole.System).slice(0, -1)
-        ),
+        messages: inferenceMessages,
         logger: resources.logger,
         tools: Object.fromEntries(
           [...actions, ...esqlFunctions].map((fn) => [
           ])
         ),
         functionCalling: simulateFunctionCalling ? 'simulated' : 'auto',
+        maxRetries: 0,
         metadata: {
           connectorTelemetry: {
             pluginId: 'observability_ai_assistant',
diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts
index a389bc83a711e..e43e71adfabab 100644
--- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts
+++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/execute_query.spec.ts
@@ -10,7 +10,7 @@ import expect from '@kbn/expect';
 import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace';
 import { last } from 'lodash';
 import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
-import { EsqlResponse } from '@elastic/elasticsearch/lib/helpers';
+import { EsqlToRecords } from '@elastic/elasticsearch/lib/helpers';
 import {
   LlmProxy,
   createLlmProxy,
@@ -167,13 +167,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
       });
     });
 
-    describe('The fourth request - Executing the ES|QL query', () => {
-      it('contains the `execute_query` tool call request', () => {
+    describe('The fourth request - executing the ES|QL query', () => {
+      it('should not contain the `execute_query` tool call request', () => {
         const hasToolCall = fourthRequestBody.messages.some(
           // @ts-expect-error
           (message) => message.tool_calls?.[0]?.function?.name === 'execute_query'
         );
-        expect(hasToolCall).to.be(true);
+        expect(hasToolCall).to.be(false);
       });
 
       it('emits a messageAdded event with the `execute_query` tool response', () => {
@@ -183,25 +183,56 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
         expect(event?.message.message.content).to.contain('simple log message');
       });
 
-      describe('the `execute_query` tool call response', () => {
-        let toolCallResponse: { columns: EsqlResponse['columns']; rows: EsqlResponse['values'] };
-        before(async () => {
-          toolCallResponse = JSON.parse(last(fourthRequestBody.messages)?.content as string);
-        });
+      describe('tool call collapsing', () => {
+        it('collapses the `execute_query` tool call into the `query` tool response', () => {
+          const content = JSON.parse(last(fourthRequestBody.messages)?.content as string);
+          expect(content.steps).to.have.length(2);
+
+          const [toolRequest, toolResponse] = content.steps;
+
+          // visualize_query tool request (sent by the LLM)
+          expect(toolRequest.role).to.be('assistant');
+          expect(toolRequest.toolCalls[0].function.name).to.be('execute_query');
 
-        it('has the correct columns', () => {
-          expect(toolCallResponse.columns.map(({ name }) => name)).to.eql([
-            'message',
-            '@timestamp',
-          ]);
+          // visualize_query tool response (sent by AI Assistant)
+          expect(toolResponse.role).to.be('tool');
+          expect(toolResponse.name).to.be('execute_query');
        });
 
-        it('has the correct number of rows', () => {
-          expect(toolCallResponse.rows.length).to.be(10);
+        it('contains the `execute_query` tool call request', () => {
+          const toolCallRequest = JSON.parse(last(fourthRequestBody.messages)?.content as string)
+            .steps[0].toolCalls[0];
+          expect(toolCallRequest.function.name).to.be('execute_query');
+          expect(toolCallRequest.function.arguments.query).to.contain(
+            'FROM logs-apache.access-default'
+          );
        });
 
-        it('has the right log message', () => {
-          expect(toolCallResponse.rows[0][0]).to.be('simple log message');
+        describe('the `execute_query` response', () => {
+          let toolCallResponse: {
+            columns: EsqlToRecords['columns'];
+            rows: EsqlToRecords['records'];
+          };
+
+          before(async () => {
+            toolCallResponse = JSON.parse(last(fourthRequestBody.messages)?.content as string)
+              .steps[1].response;
+          });
+
+          it('has the correct columns', () => {
+            expect(toolCallResponse.columns.map(({ name }) => name)).to.eql([
+              'message',
+              '@timestamp',
+            ]);
+          });
+
+          it('has the correct number of rows', () => {
+            expect(toolCallResponse.rows.length).to.be(10);
+          });
+
+          it('has the right log message', () => {
+            expect(toolCallResponse.rows[0][0]).to.be('simple log message');
+          });
        });
      });
    });