From c8043c9a9c966a95073a8138a373135f3e01b827 Mon Sep 17 00:00:00 2001 From: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:17:27 +1000 Subject: [PATCH] [8.x] [inference] Add simulated function calling (#192544) (#193275) # Backport This will backport the following commits from `main` to `8.x`: - [[inference] Add simulated function calling (#192544)](https://github.com/elastic/kibana/pull/192544) ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) Co-authored-by: Pierre Gayvallet --- .../inference/common/chat_complete/index.ts | 3 + .../inference/common/chat_complete/request.ts | 5 +- .../common/output/create_output_api.ts | 7 +- .../plugins/inference/common/output/index.ts | 3 +- .../inference/public/chat_complete/index.ts | 3 +- .../adapters/openai/openai_adapter.ts | 41 ++++-- .../inference/server/chat_complete/api.ts | 2 + .../simulated_function_calling/constants.ts | 9 ++ .../get_system_instructions.ts | 84 +++++++++++ .../simulated_function_calling/index.ts | 9 ++ .../parse_inline_function_calls.ts | 136 ++++++++++++++++++ .../wrap_with_simulated_function_calling.ts | 106 ++++++++++++++ .../inference/server/chat_complete/types.ts | 4 +- .../inference/server/routes/chat_complete.ts | 6 +- .../tasks/nl_to_esql/actions/generate_esql.ts | 4 + .../actions/request_documentation.ts | 4 + .../tasks/nl_to_esql/doc_base/aliases.ts | 3 +- .../tasks/nl_to_esql/esql_docs/esql-where.txt | 1 + .../tasks/nl_to_esql/system_message.txt | 1 - .../inference/server/tasks/nl_to_esql/task.ts | 3 + .../server/tasks/nl_to_esql/types.ts | 2 + .../chat_function_client/index.test.ts | 2 + .../service/chat_function_client/index.ts | 3 + .../server/service/client/index.test.ts | 1 + .../server/service/client/index.ts | 3 +- .../client/operators/continue_conversation.ts | 7 + .../server/service/types.ts | 1 + .../common/functions/visualize_esql.ts | 11 +- .../public/functions/visualize_esql.tsx | 7 +- .../server/functions/query/index.ts | 13 +- .../functions/query/validate_esql_query.ts | 10 +- .../server/functions/visualize_esql.ts | 11 +- 32 files changed, 472 insertions(+), 33 deletions(-) create mode 100644 x-pack/plugins/inference/server/chat_complete/simulated_function_calling/constants.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/simulated_function_calling/get_system_instructions.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/simulated_function_calling/index.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/simulated_function_calling/parse_inline_function_calls.ts create mode 100644 x-pack/plugins/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts diff --git a/x-pack/plugins/inference/common/chat_complete/index.ts b/x-pack/plugins/inference/common/chat_complete/index.ts index b42c2217c01776..aef9de12ba7a9d 100644 --- a/x-pack/plugins/inference/common/chat_complete/index.ts +++ b/x-pack/plugins/inference/common/chat_complete/index.ts @@ -78,6 +78,8 @@ export type ChatCompletionEvent | ChatCompletionTokenCountEvent | ChatCompletionMessageEvent; +export type FunctionCallingMode = 'native' | 'simulated'; + /** * Request a completion from the LLM based on a prompt or conversation. * @@ -92,5 +94,6 @@ export type ChatCompleteAPI = ( connectorId: string; system?: string; messages: Message[]; + functionCalling?: FunctionCallingMode; } & TToolOptions ) => ChatCompletionResponse; diff --git a/x-pack/plugins/inference/common/chat_complete/request.ts b/x-pack/plugins/inference/common/chat_complete/request.ts index 104d1856c9c808..1038e481a6260b 100644 --- a/x-pack/plugins/inference/common/chat_complete/request.ts +++ b/x-pack/plugins/inference/common/chat_complete/request.ts @@ -5,12 +5,13 @@ * 2.0. */ -import type { Message } from '.'; -import { ToolOptions } from './tools'; +import type { Message, FunctionCallingMode } from '.'; +import type { ToolOptions } from './tools'; export type ChatCompleteRequestBody = { connectorId: string; stream?: boolean; system?: string; messages: Message[]; + functionCalling?: FunctionCallingMode; } & ToolOptions; diff --git a/x-pack/plugins/inference/common/output/create_output_api.ts b/x-pack/plugins/inference/common/output/create_output_api.ts index 35fc2b3647004d..848135beefb0f8 100644 --- a/x-pack/plugins/inference/common/output/create_output_api.ts +++ b/x-pack/plugins/inference/common/output/create_output_api.ts @@ -12,10 +12,11 @@ import { OutputAPI, OutputEvent, OutputEventType } from '.'; import { ensureMultiTurn } from '../ensure_multi_turn'; export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI { - return (id, { connectorId, input, schema, system, previousMessages }) => { + return (id, { connectorId, input, schema, system, previousMessages, functionCalling }) => { return chatCompleteApi({ connectorId, system, + functionCalling, messages: ensureMultiTurn([ ...(previousMessages || []), { @@ -26,12 +27,12 @@ export function createOutputApi(chatCompleteApi: ChatCompleteAPI): OutputAPI { ...(schema ? { tools: { - output: { + structuredOutput: { description: `Use the following schema to respond to the user's request in structured data, so it can be parsed and handled.`, schema, }, }, - toolChoice: { function: 'output' as const }, + toolChoice: { function: 'structuredOutput' as const }, } : {}), }).pipe( diff --git a/x-pack/plugins/inference/common/output/index.ts b/x-pack/plugins/inference/common/output/index.ts index d7522f2cfa52e0..0f7655f8f1cd4b 100644 --- a/x-pack/plugins/inference/common/output/index.ts +++ b/x-pack/plugins/inference/common/output/index.ts @@ -8,7 +8,7 @@ import { Observable } from 'rxjs'; import { ServerSentEventBase } from '@kbn/sse-utils'; import { FromToolSchema, ToolSchema } from '../chat_complete/tool_schema'; -import { Message } from '../chat_complete'; +import type { Message, FunctionCallingMode } from '../chat_complete'; export enum OutputEventType { OutputUpdate = 'output', @@ -61,6 +61,7 @@ export type OutputAPI = < input: string; schema?: TOutputSchema; previousMessages?: Message[]; + functionCalling?: FunctionCallingMode; } ) => Observable< OutputEvent : undefined> diff --git a/x-pack/plugins/inference/public/chat_complete/index.ts b/x-pack/plugins/inference/public/chat_complete/index.ts index 3dfe4616b7323c..e229f6c8f8eaec 100644 --- a/x-pack/plugins/inference/public/chat_complete/index.ts +++ b/x-pack/plugins/inference/public/chat_complete/index.ts @@ -12,13 +12,14 @@ import type { ChatCompleteRequestBody } from '../../common/chat_complete/request import { httpResponseIntoObservable } from '../util/http_response_into_observable'; export function createChatCompleteApi({ http }: { http: HttpStart }): ChatCompleteAPI { - return ({ connectorId, messages, system, toolChoice, tools }) => { + return ({ connectorId, messages, system, toolChoice, tools, functionCalling }) => { const body: ChatCompleteRequestBody = { connectorId, system, messages, toolChoice, tools, + functionCalling, }; return from( diff --git a/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts b/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts index 62af864a6037dc..f1821be4d4d571 100644 --- a/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts +++ b/x-pack/plugins/inference/server/chat_complete/adapters/openai/openai_adapter.ts @@ -13,7 +13,7 @@ import type { ChatCompletionToolMessageParam, ChatCompletionUserMessageParam, } from 'openai/resources'; -import { filter, from, map, switchMap, tap, throwError } from 'rxjs'; +import { filter, from, map, switchMap, tap, throwError, identity } from 'rxjs'; import { Readable, isReadable } from 'stream'; import { ChatCompletionChunkEvent, @@ -26,18 +26,38 @@ import { createTokenLimitReachedError } from '../../../../common/chat_complete/e import { createInferenceInternalError } from '../../../../common/errors'; import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable'; import type { InferenceConnectorAdapter } from '../../types'; +import { + wrapWithSimulatedFunctionCalling, + parseInlineFunctionCalls, +} from '../../simulated_function_calling'; export const openAIAdapter: InferenceConnectorAdapter = { - chatComplete: ({ executor, system, messages, toolChoice, tools }) => { + chatComplete: ({ executor, system, messages, toolChoice, tools, functionCalling, logger }) => { const stream = true; + const simulatedFunctionCalling = functionCalling === 'simulated'; - const request: Omit & { model?: string } = { - stream, - messages: messagesToOpenAI({ system, messages }), - tool_choice: toolChoiceToOpenAI(toolChoice), - tools: toolsToOpenAI(tools), - temperature: 0, - }; + let request: Omit & { model?: string }; + if (simulatedFunctionCalling) { + const wrapped = wrapWithSimulatedFunctionCalling({ + system, + messages, + toolChoice, + tools, + }); + request = { + stream, + messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }), + temperature: 0, + }; + } else { + request = { + stream, + messages: messagesToOpenAI({ system, messages }), + tool_choice: toolChoiceToOpenAI(toolChoice), + tools: toolsToOpenAI(tools), + temperature: 0, + }; + } return from( executor.invoke({ @@ -94,7 +114,8 @@ export const openAIAdapter: InferenceConnectorAdapter = { }; }) ?? [], }; - }) + }), + simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity ); }, }; diff --git a/x-pack/plugins/inference/server/chat_complete/api.ts b/x-pack/plugins/inference/server/chat_complete/api.ts index fe879392cd4de8..ca9e61ff3627f2 100644 --- a/x-pack/plugins/inference/server/chat_complete/api.ts +++ b/x-pack/plugins/inference/server/chat_complete/api.ts @@ -31,6 +31,7 @@ export function createChatCompleteApi({ toolChoice, tools, system, + functionCalling, }): ChatCompletionResponse => { return defer(async () => { const actionsClient = await actions.getActionsClientWithRequest(request); @@ -58,6 +59,7 @@ export function createChatCompleteApi({ toolChoice, tools, logger, + functionCalling, }); }), chunksIntoMessage({ diff --git a/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/constants.ts b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/constants.ts new file mode 100644 index 00000000000000..a25deca07b7d93 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/constants.ts @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const TOOL_USE_START = '<|tool_use_start|>'; +export const TOOL_USE_END = '<|tool_use_end|>'; diff --git a/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/get_system_instructions.ts b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/get_system_instructions.ts new file mode 100644 index 00000000000000..872e842e03f864 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/get_system_instructions.ts @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { TOOL_USE_END, TOOL_USE_START } from './constants'; +import { ToolDefinition } from '../../../common/chat_complete/tools'; + +export function getSystemMessageInstructions({ + tools, +}: { + tools?: Record; +}) { + const formattedTools = Object.entries(tools ?? {}).map(([name, tool]) => { + return { + name, + ...tool, + }; + }); + + if (formattedTools.length) { + return `In this environment, you have access to a set of tools you can use to answer the user's question. + + DO NOT call a tool when it is not listed. + ONLY define input that is defined in the tool properties. + If a tool does not have properties, leave them out. + + It is EXTREMELY important that you generate valid JSON between the \`\`\`json and \`\`\` delimiters. + + You may call them like this. + + Given the following tool: + + ${JSON.stringify({ + name: 'my_tool', + description: 'A tool to call', + schema: { + type: 'object', + properties: { + myProperty: { + type: 'string', + }, + }, + }, + })} + + Use it the following way: + + ${TOOL_USE_START} + \`\`\`json + ${JSON.stringify({ name: 'my_tool', input: { myProperty: 'myValue' } })} + \`\`\`\ + ${TOOL_USE_END} + + Given the following tool: + ${JSON.stringify({ + name: 'my_tool_without_parameters', + description: 'A tool to call without parameters', + })} + + Use it the following way: + ${TOOL_USE_START} + \`\`\`json + ${JSON.stringify({ name: 'my_tool_without_parameters', input: {} })} + \`\`\`\ + ${TOOL_USE_END} + + Here are the tools available: + + ${JSON.stringify( + formattedTools.map((tool) => ({ + name: tool.name, + description: tool.description, + ...(tool.schema ? { schema: tool.schema } : {}), + })) + )} + + `; + } + + return `No tools are available anymore. DO NOT UNDER ANY CIRCUMSTANCES call any tool, regardless of whether it was previously called.`; +} diff --git a/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/index.ts b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/index.ts new file mode 100644 index 00000000000000..8863628f8af681 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/index.ts @@ -0,0 +1,9 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { wrapWithSimulatedFunctionCalling } from './wrap_with_simulated_function_calling'; +export { parseInlineFunctionCalls } from './parse_inline_function_calls'; diff --git a/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/parse_inline_function_calls.ts b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/parse_inline_function_calls.ts new file mode 100644 index 00000000000000..2fa9dd899e9860 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/parse_inline_function_calls.ts @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Observable } from 'rxjs'; +import { Logger } from '@kbn/logging'; +import { + ChatCompletionChunkEvent, + ChatCompletionTokenCountEvent, + ChatCompletionEventType, +} from '../../../common/chat_complete'; +import { createInferenceInternalError } from '../../../common/errors'; +import { TOOL_USE_END, TOOL_USE_START } from './constants'; + +function matchOnSignalStart(buffer: string) { + if (buffer.includes(TOOL_USE_START)) { + const split = buffer.split(TOOL_USE_START); + return [split[0], TOOL_USE_START + split[1]]; + } + + for (let i = 0; i < buffer.length; i++) { + const remaining = buffer.substring(i); + if (TOOL_USE_START.startsWith(remaining)) { + return [buffer.substring(0, i), remaining]; + } + } + + return false; +} + +export function parseInlineFunctionCalls({ logger }: { logger: Logger }) { + return (source: Observable) => { + let functionCallBuffer: string = ''; + + // As soon as we see a TOOL_USE_START token, we write all chunks + // to a buffer, that we flush as a function request if we + // spot the stop sequence. + + return new Observable( + (subscriber) => { + function parseFunctionCall(buffer: string) { + logger.debug('Parsing function call:\n' + buffer); + + const match = buffer.match( + /<\|tool_use_start\|>\s*```json\n?(.*?)(\n```\s*).*<\|tool_use_end\|>/s + ); + + const functionCallBody = match?.[1]; + + if (!functionCallBody) { + throw createInferenceInternalError(`Invalid function call syntax`); + } + + const parsedFunctionCall = JSON.parse(functionCallBody) as { + name?: string; + input?: unknown; + }; + + logger.debug(() => 'Parsed function call:\n ' + JSON.stringify(parsedFunctionCall)); + + if (!parsedFunctionCall.name) { + throw createInferenceInternalError(`Missing name for tool use`); + } + + subscriber.next({ + content: '', + tool_calls: [ + { + index: 0, + toolCallId: parsedFunctionCall.name, + function: { + name: parsedFunctionCall.name, + arguments: JSON.stringify(parsedFunctionCall.input || {}), + }, + }, + ], + type: ChatCompletionEventType.ChatCompletionChunk, + }); + } + + source.subscribe({ + next: (event) => { + if (event.type === ChatCompletionEventType.ChatCompletionTokenCount) { + subscriber.next(event); + return; + } + + const { type, content } = event; + + function next(contentToEmit: string) { + subscriber.next({ + type, + content: contentToEmit, + tool_calls: [], + }); + } + + const match = matchOnSignalStart(functionCallBuffer + content); + + if (match) { + const [beforeStartSignal, afterStartSignal] = match; + functionCallBuffer = afterStartSignal; + if (beforeStartSignal) { + next(beforeStartSignal); + } + + if (functionCallBuffer.includes(TOOL_USE_END)) { + const [beforeEndSignal, afterEndSignal] = functionCallBuffer.split(TOOL_USE_END); + + try { + parseFunctionCall(beforeEndSignal + TOOL_USE_END); + functionCallBuffer = ''; + next(afterEndSignal); + } catch (error) { + subscriber.error(error); + } + } + } else { + functionCallBuffer = ''; + next(content); + } + }, + complete: () => { + subscriber.complete(); + }, + error: (error) => { + subscriber.error(error); + }, + }); + } + ); + }; +} diff --git a/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts new file mode 100644 index 00000000000000..d8cfc373b66cc1 --- /dev/null +++ b/x-pack/plugins/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { AssistantMessage, Message, ToolMessage, UserMessage } from '../../../common'; +import { MessageRole } from '../../../common/chat_complete'; +import { ToolChoice, ToolChoiceType, ToolDefinition } from '../../../common/chat_complete/tools'; +import { TOOL_USE_END, TOOL_USE_START } from './constants'; +import { getSystemMessageInstructions } from './get_system_instructions'; + +function replaceFunctionsWithTools(content: string) { + return content.replaceAll(/(function)(s|[\s*\.])?(?!\scall)/g, (match, p1, p2) => { + return `tool${p2 || ''}`; + }); +} + +export function wrapWithSimulatedFunctionCalling({ + messages, + system, + tools, + toolChoice, +}: { + messages: Message[]; + system?: string; + tools?: Record; + toolChoice?: ToolChoice; +}): { messages: Message[]; system: string } { + const instructions = getSystemMessageInstructions({ + tools, + }); + + const wrappedSystem = system ? `${system}\n${instructions}` : instructions; + + const wrappedMessages = messages + .map((message) => { + if (message.role === MessageRole.Tool) { + return convertToolResponseMessage(message); + } + if (message.role === MessageRole.Assistant && message.toolCalls?.length) { + return convertToolCallMessage(message); + } + return message; + }) + .map((message) => { + return { + ...message, + content: message.content ? replaceFunctionsWithTools(message.content) : message.content, + }; + }); + + if (toolChoice) { + let selectionMessage; + if (typeof toolChoice === 'object') { + selectionMessage = `Remember, use the ${toolChoice.function} tool to answer this question.`; + } else if (toolChoice === ToolChoiceType.required) { + selectionMessage = `Remember, you MUST use one of the provided tool to answer this question.`; + } else if (toolChoice === ToolChoiceType.auto) { + selectionMessage = `Remember, you CAN use one of the provided tool to answer this question.`; + } + + if (selectionMessage) { + wrappedMessages[messages.length - 1].content += `\n${selectionMessage}`; + } + } + + return { + messages: wrappedMessages as Message[], + system: wrappedSystem, + }; +} + +const convertToolResponseMessage = (message: ToolMessage): UserMessage => { + return { + role: MessageRole.User, + content: JSON.stringify({ + type: 'tool_result', + tool: message.toolCallId, + response: message.response, + }), + }; +}; + +const convertToolCallMessage = (message: AssistantMessage): AssistantMessage => { + // multi-call not supported by simulated mode, there will never be more than one + const toolCall = message.toolCalls![0]; + + let content = message.content || ''; + + content += + TOOL_USE_START + + '\n```json\n' + + JSON.stringify({ + name: toolCall.function.name, + input: 'arguments' in toolCall.function ? toolCall.function.arguments : {}, + }) + + '\n```' + + TOOL_USE_END; + + return { + role: MessageRole.Assistant, + content, + }; +}; diff --git a/x-pack/plugins/inference/server/chat_complete/types.ts b/x-pack/plugins/inference/server/chat_complete/types.ts index 5ef28fdbdc8082..394fe370240efe 100644 --- a/x-pack/plugins/inference/server/chat_complete/types.ts +++ b/x-pack/plugins/inference/server/chat_complete/types.ts @@ -10,6 +10,7 @@ import type { Logger } from '@kbn/logging'; import type { ChatCompletionChunkEvent, ChatCompletionTokenCountEvent, + FunctionCallingMode, Message, } from '../../common/chat_complete'; import type { ToolOptions } from '../../common/chat_complete/tools'; @@ -24,9 +25,10 @@ import type { InferenceExecutor } from './utils'; export interface InferenceConnectorAdapter { chatComplete: ( options: { + executor: InferenceExecutor; messages: Message[]; system?: string; - executor: InferenceExecutor; + functionCalling?: FunctionCallingMode; logger: Logger; } & ToolOptions ) => Observable; diff --git a/x-pack/plugins/inference/server/routes/chat_complete.ts b/x-pack/plugins/inference/server/routes/chat_complete.ts index 5a9c0aae509586..fdf33fbf0af826 100644 --- a/x-pack/plugins/inference/server/routes/chat_complete.ts +++ b/x-pack/plugins/inference/server/routes/chat_complete.ts @@ -71,6 +71,9 @@ const chatCompleteBodySchema: Type = schema.object({ }), ]) ), + functionCalling: schema.maybe( + schema.oneOf([schema.literal('native'), schema.literal('simulated')]) + ), }); export function registerChatCompleteRoute({ @@ -96,7 +99,7 @@ export function registerChatCompleteRoute({ const client = createInferenceClient({ request, actions, logger }); - const { connectorId, messages, system, toolChoice, tools } = request.body; + const { connectorId, messages, system, toolChoice, tools, functionCalling } = request.body; const chatCompleteResponse = client.chatComplete({ connectorId, @@ -104,6 +107,7 @@ export function registerChatCompleteRoute({ system, toolChoice, tools, + functionCalling, }); return response.ok({ diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts index 8a111322a8de62..d31952e2f52520 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/generate_esql.ts @@ -21,6 +21,7 @@ import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/con import { EsqlDocumentBase } from '../doc_base'; import { requestDocumentationSchema } from './shared'; import type { NlToEsqlTaskEvent } from '../types'; +import type { FunctionCallingMode } from '../../../../common/chat_complete'; export const generateEsqlTask = ({ chatCompleteApi, @@ -29,6 +30,7 @@ export const generateEsqlTask = ({ messages, toolOptions: { tools, toolChoice }, docBase, + functionCalling, logger, }: { connectorId: string; @@ -37,6 +39,7 @@ export const generateEsqlTask = ({ toolOptions: ToolOptions; chatCompleteApi: InferenceClient['chatComplete']; docBase: EsqlDocumentBase; + functionCalling?: FunctionCallingMode; logger: Pick; }) => { return function askLlmToRespond({ @@ -65,6 +68,7 @@ export const generateEsqlTask = ({ }), chatCompleteApi({ connectorId, + functionCalling, system: `${systemMessage} # Current task diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts index 05f454c044d31c..d4eb3060f59bbf 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/actions/request_documentation.ts @@ -10,24 +10,28 @@ import { InferenceClient, withoutOutputUpdateEvents } from '../../..'; import { Message } from '../../../../common'; import { ToolChoiceType, ToolOptions } from '../../../../common/chat_complete/tools'; import { requestDocumentationSchema } from './shared'; +import type { FunctionCallingMode } from '../../../../common/chat_complete'; export const requestDocumentation = ({ outputApi, system, messages, connectorId, + functionCalling, toolOptions: { tools, toolChoice }, }: { outputApi: InferenceClient['output']; system: string; messages: Message[]; connectorId: string; + functionCalling?: FunctionCallingMode; toolOptions: ToolOptions; }) => { const hasTools = !isEmpty(tools) && toolChoice !== ToolChoiceType.none; return outputApi('request_documentation', { connectorId, + functionCalling, system, previousMessages: messages, input: `Based on the previous conversation, request documentation diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts index 29f07af2d11218..6df382a57fd615 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/doc_base/aliases.ts @@ -10,7 +10,8 @@ * This is mostly for the case for STATS. */ const aliases: Record = { - STATS: ['STATS_BY', 'BY', 'STATS...BY'], + STATS: ['STATS_BY', 'BY', 'STATS...BY', 'STATS ... BY'], + OPERATORS: ['LIKE', 'RLIKE', 'IN'], }; const getAliasMap = () => { diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/esql_docs/esql-where.txt b/x-pack/plugins/inference/server/tasks/nl_to_esql/esql_docs/esql-where.txt index b9b70ebad625e8..ccd7e12517ffba 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/esql_docs/esql-where.txt +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/esql_docs/esql-where.txt @@ -21,6 +21,7 @@ WHERE supports the following types of functions: - Type conversation functions - Conditional functions and expressions - Multi-value functions +- Operators Aggregation functions are WHERE supported for EVAL. diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/system_message.txt b/x-pack/plugins/inference/server/tasks/nl_to_esql/system_message.txt index 2efa08a6288c0e..da590d9531ccb9 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/system_message.txt +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/system_message.txt @@ -185,7 +185,6 @@ Binary operators: ==, !=, <, <=, >, >=, +, -, *, /, % Logical operators: AND, OR, NOT Predicates: IS NULL, IS NOT NULL Unary operators: - - IN LIKE: filter data based on string patterns using wildcards RLIKE: filter data based on string patterns using regular expressions diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts index 04b879351cc54f..e0c5a838ea148f 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/task.ts @@ -21,6 +21,7 @@ export function naturalLanguageToEsql({ tools, toolChoice, logger, + functionCalling, ...rest }: NlToEsqlTaskParams): Observable> { return from(loadDocBase()).pipe( @@ -36,6 +37,7 @@ export function naturalLanguageToEsql({ docBase, logger, systemMessage, + functionCalling, toolOptions: { tools, toolChoice, @@ -44,6 +46,7 @@ export function naturalLanguageToEsql({ return requestDocumentation({ connectorId, + functionCalling, outputApi: client.output, messages, system: systemMessage, diff --git a/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts index c460f029b147e6..a0bcd635081ead 100644 --- a/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts +++ b/x-pack/plugins/inference/server/tasks/nl_to_esql/types.ts @@ -9,6 +9,7 @@ import type { Logger } from '@kbn/logging'; import type { ChatCompletionChunkEvent, ChatCompletionMessageEvent, + FunctionCallingMode, Message, } from '../../../common/chat_complete'; import type { ToolOptions } from '../../../common/chat_complete/tools'; @@ -27,5 +28,6 @@ export type NlToEsqlTaskParams = { client: Pick; connectorId: string; logger: Pick; + functionCalling?: FunctionCallingMode; } & TToolOptions & ({ input: string } | { messages: Message[] }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts index 9d6c0dba0b124b..3d83c470de0c5a 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.test.ts @@ -49,6 +49,7 @@ describe('chatFunctionClient', () => { messages: [], signal: new AbortController().signal, connectorId: 'foo', + useSimulatedFunctionCalling: false, }); }).rejects.toThrowError(`Function arguments are invalid`); @@ -109,6 +110,7 @@ describe('chatFunctionClient', () => { messages: [], signal: new AbortController().signal, connectorId: 'foo', + useSimulatedFunctionCalling: false, }); expect(result).toEqual({ diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts index fa1d0e5fd669de..039d7347c715ef 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/chat_function_client/index.ts @@ -147,6 +147,7 @@ export class ChatFunctionClient { messages, signal, connectorId, + useSimulatedFunctionCalling, }: { chat: FunctionCallChatFunction; name: string; @@ -154,6 +155,7 @@ export class ChatFunctionClient { messages: Message[]; signal: AbortSignal; connectorId: string; + useSimulatedFunctionCalling: boolean; }): Promise { const fn = this.functionRegistry.get(name); @@ -172,6 +174,7 @@ export class ChatFunctionClient { screenContexts: this.screenContexts, chat, connectorId, + useSimulatedFunctionCalling, }, signal ); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts index a0accea06370b1..a3c1d72fefbab8 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.test.ts @@ -850,6 +850,7 @@ describe('Observability AI Assistant client', () => { }, }, ], + useSimulatedFunctionCalling: false, }); }); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts index f5839b76effe8c..1e995b66059c2b 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/index.ts @@ -162,7 +162,7 @@ export class ObservabilityAIAssistantClient { complete = ({ functionClient, connectorId, - simulateFunctionCalling, + simulateFunctionCalling = false, instructions: adHocInstructions = [], messages: initialMessages, signal, @@ -299,6 +299,7 @@ export class ObservabilityAIAssistantClient { disableFunctions, tracer: completeTracer, connectorId, + useSimulatedFunctionCalling: simulateFunctionCalling === true, }) ); }), diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts index da172c974e9e2c..66204c96f31cb1 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/client/operators/continue_conversation.ts @@ -54,6 +54,7 @@ function executeFunctionAndCatchError({ logger, tracer, connectorId, + useSimulatedFunctionCalling, }: { name: string; args: string | undefined; @@ -64,6 +65,7 @@ function executeFunctionAndCatchError({ logger: Logger; tracer: LangTracer; connectorId: string; + useSimulatedFunctionCalling: boolean; }): Observable { // hide token count events from functions to prevent them from // having to deal with it as well @@ -84,6 +86,7 @@ function executeFunctionAndCatchError({ signal, messages, connectorId, + useSimulatedFunctionCalling, }) ); @@ -181,6 +184,7 @@ export function continueConversation({ disableFunctions, tracer, connectorId, + useSimulatedFunctionCalling, }: { messages: Message[]; functionClient: ChatFunctionClient; @@ -197,6 +201,7 @@ export function continueConversation({ }; tracer: LangTracer; connectorId: string; + useSimulatedFunctionCalling: boolean; }): Observable { let nextFunctionCallsLeft = functionCallsLeft; @@ -310,6 +315,7 @@ export function continueConversation({ logger, tracer, connectorId, + useSimulatedFunctionCalling, }); } @@ -338,6 +344,7 @@ export function continueConversation({ disableFunctions, tracer, connectorId, + useSimulatedFunctionCalling, }); }) ) diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts index 9ae585af9071c7..ebc54daf367398 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant/server/service/types.ts @@ -55,6 +55,7 @@ type RespondFunction = ( screenContexts: ObservabilityAIAssistantScreenContextRequest[]; chat: FunctionCallChatFunction; connectorId: string; + useSimulatedFunctionCalling: boolean; }, signal: AbortSignal ) => Promise; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/functions/visualize_esql.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/functions/visualize_esql.ts index ebdfbf32abac62..499d885d1ab34d 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/functions/visualize_esql.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/common/functions/visualize_esql.ts @@ -47,7 +47,16 @@ export interface VisualizeQueryResponsev1 { }; } -export type VisualizeQueryResponse = VisualizeQueryResponsev0 | VisualizeQueryResponsev1; +export type VisualizeQueryResponsev2 = VisualizeQueryResponsev1 & { + data: { + correctedQuery: string; + }; +}; + +export type VisualizeQueryResponse = + | VisualizeQueryResponsev0 + | VisualizeQueryResponsev1 + | VisualizeQueryResponsev2; export type VisualizeESQLFunctionArguments = FromSchema< (typeof visualizeESQLFunction)['parameters'] diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.tsx b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.tsx index 404ff9e32a4db8..e1889c7bc199ab 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.tsx +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/public/functions/visualize_esql.tsx @@ -419,6 +419,11 @@ export function registerVisualizeQueryRenderFunction({ ? typedResponse.content.errorMessages : []; + const correctedQuery = + 'data' in typedResponse && 'correctedQuery' in typedResponse.data + ? typedResponse.data.correctedQuery + : query; + if ('data' in typedResponse && 'userOverrides' in typedResponse.data) { userOverrides = typedResponse.data.userOverrides; } @@ -472,7 +477,7 @@ export function registerVisualizeQueryRenderFunction({ break; } - const trimmedQuery = query.trim(); + const trimmedQuery = correctedQuery.trim(); return ( { + const correctedQuery = correctCommonEsqlMistakes(query).output; + const client = (await resources.context.core).elasticsearch.client.asCurrentUser; const { error, errorMessages, rows, columns } = await runAndValidateEsqlQuery({ - query, + query: correctedQuery, client, }); @@ -108,7 +114,7 @@ export function registerQueryFunction({ function takes no input.`, visibility: FunctionVisibility.AssistantOnly, }, - async ({ messages, connectorId }, signal) => { + async ({ messages, connectorId, useSimulatedFunctionCalling }, signal) => { const esqlFunctions = functions .getFunctions() .filter( @@ -132,6 +138,7 @@ export function registerQueryFunction({ .concat(esqlFunctions) .map((fn) => [fn.name, { description: fn.description, schema: fn.parameters }]) ), + functionCalling: useSimulatedFunctionCalling ? 'simulated' : 'native', }); const chatMessageId = v4(); diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/validate_esql_query.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/validate_esql_query.ts index ac26846f940e6b..1c36d085945217 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/validate_esql_query.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/query/validate_esql_query.ts @@ -25,16 +25,20 @@ export async function runAndValidateEsqlQuery({ error?: Error; errorMessages?: string[]; }> { - const { errors } = await validateQuery(query, getAstAndSyntaxErrors, { + const queryWithoutLineBreaks = query.replaceAll(/\n/g, ''); + + const { errors } = await validateQuery(queryWithoutLineBreaks, getAstAndSyntaxErrors, { // setting this to true, we don't want to validate the index / fields existence ignoreOnMissingCallbacks: true, }); - const asCommands = splitIntoCommands(query); + const asCommands = splitIntoCommands(queryWithoutLineBreaks); const errorMessages = errors?.map((error) => { if ('location' in error) { - const commandsUntilEndOfError = splitIntoCommands(query.substring(0, error.location.max)); + const commandsUntilEndOfError = splitIntoCommands( + queryWithoutLineBreaks.substring(0, error.location.max) + ); const lastCompleteCommand = asCommands[commandsUntilEndOfError.length - 1]; if (lastCompleteCommand) { return `Error in ${lastCompleteCommand.command}\n: ${error.text}`; diff --git a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts index bca5b04e2da06d..4eeba0450e6e4e 100644 --- a/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts +++ b/x-pack/plugins/observability_solution/observability_ai_assistant_app/server/functions/visualize_esql.ts @@ -5,9 +5,10 @@ * 2.0. */ import { VisualizeESQLUserIntention } from '@kbn/observability-ai-assistant-plugin/common/functions/visualize_esql'; +import { correctCommonEsqlMistakes } from '@kbn/inference-plugin/common'; import { visualizeESQLFunction, - type VisualizeQueryResponsev1, + VisualizeQueryResponsev2, } from '../../common/functions/visualize_esql'; import type { FunctionRegistrationParameters } from '.'; import { runAndValidateEsqlQuery } from './query/validate_esql_query'; @@ -32,12 +33,15 @@ export function registerVisualizeESQLFunction({ }: FunctionRegistrationParameters) { functions.registerFunction( visualizeESQLFunction, - async ({ arguments: { query, intention } }): Promise => { + async ({ arguments: { query, intention } }): Promise => { // errorMessages contains the syntax errors from the client side valdation // error contains the error from the server side validation, it is always one error // and help us identify errors like index not found, field not found etc. + + const correctedQuery = correctCommonEsqlMistakes(query).output; + const { columns, errorMessages, rows, error } = await runAndValidateEsqlQuery({ - query, + query: correctedQuery, client: (await resources.context.core).elasticsearch.client.asCurrentUser, }); @@ -47,6 +51,7 @@ export function registerVisualizeESQLFunction({ data: { columns: columns ?? [], rows: rows ?? [], + correctedQuery, }, content: { message,