diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts index 0f73862077589..2190f9dc081df 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/query/index.ts @@ -5,9 +5,17 @@ * 2.0. */ -import { ToolDefinition, isChatCompletionChunkEvent, isOutputEvent } from '@kbn/inference-common'; +import { map } from 'rxjs'; +import { v4 } from 'uuid'; +import { + ToolDefinition, + ToolChoice, + isChatCompletionChunkEvent, + isOutputEvent, +} from '@kbn/inference-common'; import { correctCommonEsqlMistakes } from '@kbn/inference-plugin/common'; import { naturalLanguageToEsql } from '@kbn/inference-plugin/server'; +import { safeJsonParse } from '@kbn/std'; import { MessageAddEvent, MessageRole, @@ -15,8 +23,6 @@ import { } from '@kbn/observability-ai-assistant-plugin/common'; import { createFunctionResponseMessage } from '@kbn/observability-ai-assistant-plugin/common/utils/create_function_response_message'; import { convertMessagesForInference } from '@kbn/observability-ai-assistant-plugin/common/convert_messages_for_inference'; -import { map } from 'rxjs'; -import { v4 } from 'uuid'; import { VISUALIZE_QUERY_NAME } from '../../../common/functions/visualize_esql'; import type { FunctionRegistrationParameters } from '..'; import { runAndValidateEsqlQuery } from './validate_esql_query'; @@ -24,6 +30,9 @@ import { runAndValidateEsqlQuery } from './validate_esql_query'; export const QUERY_FUNCTION_NAME = 'query'; export const EXECUTE_QUERY_NAME = 'execute_query'; +export const QUERY_INTENT_VALUES = ['example', 'data', 'visual'] as const; +export type QueryIntent = (typeof QUERY_INTENT_VALUES)[number]; + export function registerQueryFunction({ functions, resources, @@ -56,7 +65,7 @@ export function registerQueryFunction({ name: EXECUTE_QUERY_NAME, isInternal: true, description: `Execute a generated ES|QL query on behalf of the user. The results - will be returned to you. + will be returned to you. You must use this function if the user is asking for the result of a query, such as a metric or list of things, but does not want to visualize it in @@ -111,6 +120,18 @@ export function registerQueryFunction({ convert queries from one language to another. Make sure you call one of the get_dataset functions first if you need index or field names. This function takes no input.`, + parameters: { + type: 'object', + properties: { + queryIntent: { + type: 'string', + enum: QUERY_INTENT_VALUES, + description: + 'Controls how the query function behaves: generate query only, execute the query, or visualize results', + }, + }, + required: ['queryIntent'], + } as const, }, async ({ messages, connectorId, simulateFunctionCalling }) => { const esqlFunctions = functions @@ -123,23 +144,51 @@ export function registerQueryFunction({ const actions = functions.getActions(); + // Remove system messages + const nonSystemMessages = messages.filter((msg) => msg.message.role !== MessageRole.System); + + const queryRequestMessage = nonSystemMessages[nonSystemMessages.length - 1]; + + // Extract query intent argument + let queryIntent: QueryIntent | undefined; + if (queryRequestMessage?.message?.function_call?.arguments) { + const args = safeJsonParse<{ queryIntent?: QueryIntent }>( + queryRequestMessage.message.function_call.arguments + ); + queryIntent = args?.queryIntent; + } + const inferenceMessages = convertMessagesForInference( - // remove system message and query function request - messages.filter((message) => message.message.role !== MessageRole.System).slice(0, -1), + // Remove query function request + [...nonSystemMessages.slice(0, -1)], resources.logger ); + // decide toolChoice based on queryIntent + let toolChoice: ToolChoice | undefined; + if (queryIntent === 'data') { + toolChoice = { function: EXECUTE_QUERY_NAME }; + } else if (queryIntent === 'visual') { + toolChoice = { function: VISUALIZE_QUERY_NAME }; + } + + // drop query execution/visualization when only an example is requested + const esqlToolDefinitions = queryIntent === 'example' ? [] : esqlFunctions; + + const availableToolDefinitions = Object.fromEntries( + [...actions, ...esqlToolDefinitions].map((fn) => [ + fn.name, + { description: fn.description, schema: fn.parameters } as ToolDefinition, + ]) + ); + const events$ = naturalLanguageToEsql({ client: pluginsStart.inference.getClient({ request: resources.request }), connectorId, messages: inferenceMessages, logger: resources.logger, - tools: Object.fromEntries( - [...actions, ...esqlFunctions].map((fn) => [ - fn.name, - { description: fn.description, schema: fn.parameters } as ToolDefinition, - ]) - ), + tools: availableToolDefinitions, + toolChoice, functionCalling: simulateFunctionCalling ? 'simulated' : 'auto', maxRetries: 0, metadata: { diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/tsconfig.json b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/tsconfig.json index 5c7fb703ab32f..02a1d5507cbdf 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/tsconfig.json +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/tsconfig.json @@ -85,6 +85,7 @@ "@kbn/i18n-react", "@kbn/utility-types", "@kbn/alerts-ui-shared", + "@kbn/std", "@kbn/traced-es-client" ], "exclude": ["target/**/*"]