Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@ import {
Message,
MessageRole,
OutputCompleteEvent,
OutputEventType,
ChatCompleteMetadata,
ChatCompleteOptions,
ChatCompleteAPI,
OutputEventType,
ToolChoiceType,
} from '@kbn/inference-common';
import { correctCommonEsqlMistakes, generateFakeToolCallId } from '../../../../common';
import { INLINE_ESQL_QUERY_REGEX } from '../../../../common/tasks/nl_to_esql/constants';
import { EsqlDocumentBase } from '../doc_base';
import { requestDocumentationSchema } from './shared';
import type { NlToEsqlTaskEvent } from '../types';

const MAX_CALLS = 5;

export const generateEsqlTask = <TToolOptions extends ToolOptions>({
chatCompleteApi,
connectorId,
Expand All @@ -39,6 +42,7 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
logger,
system,
metadata,
maxCallsAllowed = MAX_CALLS,
}: {
connectorId: string;
systemMessage: string;
Expand All @@ -49,12 +53,16 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
logger: Pick<Logger, 'debug'>;
metadata?: ChatCompleteMetadata;
system?: string;
maxCallsAllowed?: number;
} & Pick<ChatCompleteOptions, 'maxRetries' | 'retryConfiguration' | 'functionCalling'>) => {
return function askLlmToRespond({
documentationRequest: { commands, functions },
callCount = 0,
}: {
documentationRequest: { commands?: string[]; functions?: string[] };
callCount?: number;
}): Observable<NlToEsqlTaskEvent<TToolOptions>> {
const functionLimitReached = callCount >= maxCallsAllowed;
const keywords = [...(commands ?? []), ...(functions ?? [])];
const requestedDocumentation = docBase.getDocumentation(keywords);
const fakeRequestDocsToolCall = createFakeTooCall(commands, functions);
Expand Down Expand Up @@ -123,14 +131,16 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
toolCallId: fakeRequestDocsToolCall.toolCallId,
},
],
toolChoice,
tools: {
...tools,
request_documentation: {
description: 'Request additional ES|QL documentation if needed',
schema: requestDocumentationSchema,
},
},
toolChoice: !functionLimitReached ? toolChoice : ToolChoiceType.none,
tools: functionLimitReached
? {}
: {
...tools,
request_documentation: {
description: 'Request additional ES|QL documentation if needed',
schema: requestDocumentationSchema,
},
},
}).pipe(
withoutTokenCountEvents(),
map((generateEvent) => {
Expand All @@ -147,18 +157,32 @@ export const generateEsqlTask = <TToolOptions extends ToolOptions>({
}),
switchMap((generateEvent) => {
if (isChatCompletionMessageEvent(generateEvent)) {
const onlyToolCall =
generateEvent.toolCalls.length === 1 ? generateEvent.toolCalls[0] : undefined;

if (onlyToolCall?.function.name === 'request_documentation') {
const args = onlyToolCall.function.arguments;

return askLlmToRespond({
documentationRequest: {
commands: args.commands,
functions: args.functions,
},
});
const toolCalls = generateEvent.toolCalls as ToolCall[];
const onlyToolCall = toolCalls.length === 1 ? toolCalls[0] : undefined;

if (onlyToolCall && onlyToolCall.function.name === 'request_documentation') {
if (functionLimitReached) {
return of({
...generateEvent,
content: `You have reached the maximum number of documentation requests. Do not try to request documentation again for commands ${commands?.join(
', '
)} and functions ${functions?.join(
', '
)}. Try to answer the user's question using currently available information.`,
});
}

const args =
'arguments' in onlyToolCall.function ? onlyToolCall.function.arguments : undefined;
if (args && (args.commands?.length || args.functions?.length)) {
return askLlmToRespond({
documentationRequest: {
commands: args.commands ?? [],
functions: args.functions ?? [],
},
callCount: callCount + 1,
});
}
}
}

Expand Down