diff --git a/x-pack/platform/plugins/shared/inference/server/tasks/nl_to_esql/actions/generate_esql.ts b/x-pack/platform/plugins/shared/inference/server/tasks/nl_to_esql/actions/generate_esql.ts index a8091da5c9f67..1f7b3dd14ec09 100644 --- a/x-pack/platform/plugins/shared/inference/server/tasks/nl_to_esql/actions/generate_esql.ts +++ b/x-pack/platform/plugins/shared/inference/server/tasks/nl_to_esql/actions/generate_esql.ts @@ -15,10 +15,11 @@ import { Message, MessageRole, OutputCompleteEvent, - OutputEventType, ChatCompleteMetadata, ChatCompleteOptions, ChatCompleteAPI, + OutputEventType, + ToolChoiceType, } from '@kbn/inference-common'; import { correctCommonEsqlMistakes, generateFakeToolCallId } from '../../../../common'; import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants'; @@ -26,6 +27,8 @@ import { EsqlDocumentBase } from '../doc_base'; import { requestDocumentationSchema } from './shared'; import type { NlToEsqlTaskEvent } from '../types'; +const MAX_CALLS = 5; + export const generateEsqlTask = ({ chatCompleteApi, connectorId, @@ -39,6 +42,7 @@ export const generateEsqlTask = ({ logger, system, metadata, + maxCallsAllowed = MAX_CALLS, }: { connectorId: string; systemMessage: string; @@ -49,12 +53,16 @@ export const generateEsqlTask = ({ logger: Pick; metadata?: ChatCompleteMetadata; system?: string; + maxCallsAllowed?: number; } & Pick) => { return function askLlmToRespond({ documentationRequest: { commands, functions }, + callCount = 0, }: { documentationRequest: { commands?: string[]; functions?: string[] }; + callCount?: number; }): Observable> { + const functionLimitReached = callCount >= maxCallsAllowed; const keywords = [...(commands ?? []), ...(functions ?? [])]; const requestedDocumentation = docBase.getDocumentation(keywords); const fakeRequestDocsToolCall = createFakeTooCall(commands, functions); @@ -123,14 +131,16 @@ export const generateEsqlTask = ({ toolCallId: fakeRequestDocsToolCall.toolCallId, }, ], - toolChoice, - tools: { - ...tools, - request_documentation: { - description: 'Request additional ES|QL documentation if needed', - schema: requestDocumentationSchema, - }, - }, + toolChoice: !functionLimitReached ? toolChoice : ToolChoiceType.none, + tools: functionLimitReached + ? {} + : { + ...tools, + request_documentation: { + description: 'Request additional ES|QL documentation if needed', + schema: requestDocumentationSchema, + }, + }, }).pipe( withoutTokenCountEvents(), map((generateEvent) => { @@ -147,18 +157,32 @@ export const generateEsqlTask = ({ }), switchMap((generateEvent) => { if (isChatCompletionMessageEvent(generateEvent)) { - const onlyToolCall = - generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined; - - if (onlyToolCall?.function.name === 'request_documentation') { - const args = onlyToolCall.function.arguments; - - return askLlmToRespond({ - documentationRequest: { - commands: args.commands, - functions: args.functions, - }, - }); + const toolCalls = generateEvent.toolCalls as ToolCall[]; + const onlyToolCall = toolCalls.length === 1 ? toolCalls[0] : undefined; + + if (onlyToolCall && onlyToolCall.function.name === 'request_documentation') { + if (functionLimitReached) { + return of({ + ...generateEvent, + content: `You have reached the maximum number of documentation requests. Do not try to request documentation again for commands ${commands?.join( + ', ' + )} and functions ${functions?.join( + ', ' + )}. Try to answer the user's question using currently available information.`, + }); + } + + const args = + 'arguments' in onlyToolCall.function ? onlyToolCall.function.arguments : undefined; + if (args && (args.commands?.length || args.functions?.length)) { + return askLlmToRespond({ + documentationRequest: { + commands: args.commands ?? [], + functions: args.functions ?? [], + }, + callCount: callCount + 1, + }); + } } }