Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,34 @@
* 2.0.
*/

import { ToolDefinition, isChatCompletionChunkEvent, isOutputEvent } from '@kbn/inference-common';
import { map } from 'rxjs';
import { v4 } from 'uuid';
import {
ToolDefinition,
ToolChoice,
isChatCompletionChunkEvent,
isOutputEvent,
} from '@kbn/inference-common';
import { correctCommonEsqlMistakes } from '@kbn/inference-plugin/common';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { safeJsonParse } from '@kbn/std';
import {
MessageAddEvent,
MessageRole,
StreamingChatResponseEventType,
} from '@kbn/observability-ai-assistant-plugin/common';
import { createFunctionResponseMessage } from '@kbn/observability-ai-assistant-plugin/common/utils/create_function_response_message';
import { convertMessagesForInference } from '@kbn/observability-ai-assistant-plugin/common/convert_messages_for_inference';
import { map } from 'rxjs';
import { v4 } from 'uuid';
import { VISUALIZE_QUERY_NAME } from '../../../common/functions/visualize_esql';
import type { FunctionRegistrationParameters } from '..';
import { runAndValidateEsqlQuery } from './validate_esql_query';

export const QUERY_FUNCTION_NAME = 'query';
export const EXECUTE_QUERY_NAME = 'execute_query';

export const QUERY_INTENT_VALUES = ['example', 'data', 'visual'] as const;
export type QueryIntent = (typeof QUERY_INTENT_VALUES)[number];

export function registerQueryFunction({
functions,
resources,
Expand Down Expand Up @@ -56,7 +65,7 @@ export function registerQueryFunction({
name: EXECUTE_QUERY_NAME,
isInternal: true,
description: `Execute a generated ES|QL query on behalf of the user. The results
will be returned to you.
will be returned to you.

You must use this function if the user is asking for the result of a query,
such as a metric or list of things, but does not want to visualize it in
Expand Down Expand Up @@ -111,6 +120,18 @@ export function registerQueryFunction({
convert queries from one language to another. Make sure you call one of
the get_dataset functions first if you need index or field names. This
function takes no input.`,
parameters: {
type: 'object',
properties: {
queryIntent: {
type: 'string',
enum: QUERY_INTENT_VALUES,
description:
'Controls how the query function behaves: generate query only, execute the query, or visualize results',
},
},
required: ['queryIntent'],
} as const,
},
async ({ messages, connectorId, simulateFunctionCalling }) => {
const esqlFunctions = functions
Expand All @@ -123,23 +144,51 @@ export function registerQueryFunction({

const actions = functions.getActions();

// Remove system messages
const nonSystemMessages = messages.filter((msg) => msg.message.role !== MessageRole.System);

const queryRequestMessage = nonSystemMessages[nonSystemMessages.length - 1];

// Extract query intent argument
let queryIntent: QueryIntent | undefined;
if (queryRequestMessage?.message?.function_call?.arguments) {
const args = safeJsonParse<{ queryIntent?: QueryIntent }>(
queryRequestMessage.message.function_call.arguments
);
queryIntent = args?.queryIntent;
}

const inferenceMessages = convertMessagesForInference(
// remove system message and query function request
messages.filter((message) => message.message.role !== MessageRole.System).slice(0, -1),
// Remove query function request
[...nonSystemMessages.slice(0, -1)],
resources.logger
);

// decide toolChoice based on queryIntent
let toolChoice: ToolChoice<string> | undefined;
if (queryIntent === 'data') {
toolChoice = { function: EXECUTE_QUERY_NAME };
} else if (queryIntent === 'visual') {
toolChoice = { function: VISUALIZE_QUERY_NAME };
}

// drop query execution/visualization when only an example is requested
const esqlToolDefinitions = queryIntent === 'example' ? [] : esqlFunctions;

const availableToolDefinitions = Object.fromEntries(
[...actions, ...esqlToolDefinitions].map((fn) => [
fn.name,
{ description: fn.description, schema: fn.parameters } as ToolDefinition,
])
);

const events$ = naturalLanguageToEsql({
client: pluginsStart.inference.getClient({ request: resources.request }),
connectorId,
messages: inferenceMessages,
logger: resources.logger,
tools: Object.fromEntries(
[...actions, ...esqlFunctions].map((fn) => [
fn.name,
{ description: fn.description, schema: fn.parameters } as ToolDefinition,
])
),
tools: availableToolDefinitions,
toolChoice,
functionCalling: simulateFunctionCalling ? 'simulated' : 'auto',
maxRetries: 0,
metadata: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"@kbn/i18n-react",
"@kbn/utility-types",
"@kbn/alerts-ui-shared",
"@kbn/std",
"@kbn/traced-es-client"
],
"exclude": ["target/**/*"]
Expand Down