Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* 2.0.
*/

import { isImpliedDefaultElserInferenceId } from './is_default_inference_endpoint';
import { type ProductName, DocumentationProduct } from './product';

const allowedProductNames: ProductName[] = Object.values(DocumentationProduct);
Expand All @@ -24,7 +25,7 @@ export const getArtifactName = ({
}): string => {
const ext = excludeExtension ? '' : '.zip';
return `kb-product-doc-${productName}-${productVersion}${
inferenceId && inferenceId !== DEFAULT_ELSER ? `--${inferenceId}` : ''
inferenceId && !isImpliedDefaultElserInferenceId(inferenceId) ? `--${inferenceId}` : ''
}${ext}`.toLowerCase();
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const isImpliedDefaultElserInferenceId = (inferenceId: string | null | un
inferenceId === null ||
inferenceId === undefined ||
inferenceId === defaultInferenceEndpoints.ELSER ||
inferenceId === defaultInferenceEndpoints.ELSER_IN_EIS_INFERENCE_ID
inferenceId === defaultInferenceEndpoints.ELSER_IN_EIS_INFERENCE_ID ||
(typeof inferenceId === 'string' && inferenceId.toLowerCase().includes('elser'))
);
};
6 changes: 3 additions & 3 deletions x-pack/platform/plugins/shared/ai_infra/llm_tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ context.

That API receives the inference endpoint ID as a parameter.

Example:
Example — by default, availability is checked against the default ELSER model:
```ts
if (await llmTasksStart.retrieveDocumentationAvailable({ request })) {
if (await llmTasksStart.retrieveDocumentationAvailable({ inferenceId })) {
// task is available
} else {
// task is not available
}
```

### Executing the task

Expand All @@ -37,6 +36,7 @@ const result = await llmTasksStart.retrieveDocumentation({
searchTerm: "How to create a space in Kibana?",
request,
connectorId: 'my-connector-id',
inferenceId: 'my-inference-id',
});

const { success, documents } = result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import type { Logger } from '@kbn/logging';
import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server';
import { defaultInferenceEndpoints } from '@kbn/inference-common';
import type { LlmTasksConfig } from './config';
import type {
LlmTasksPluginSetup,
Expand Down Expand Up @@ -41,9 +40,9 @@ export class LlmTasksPlugin
start(core: CoreStart, startDependencies: PluginStartDependencies): LlmTasksPluginStart {
const { inference, productDocBase } = startDependencies;
return {
retrieveDocumentationAvailable: async () => {
retrieveDocumentationAvailable: async (options: { inferenceId: string }) => {
const docBaseStatus = await startDependencies.productDocBase.management.getStatus({
inferenceId: defaultInferenceEndpoints.ELSER,
inferenceId: options.inferenceId,
});
return docBaseStatus.status === 'installed';
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export interface LlmTasksPluginStart {
* are respected. Can be used to check if the task can be registered
* as LLM tool for example.
*/
retrieveDocumentationAvailable: () => Promise<boolean>;
retrieveDocumentationAvailable: (options: { inferenceId: string }) => Promise<boolean>;
/**
* Perform the `retrieveDocumentation` task.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { defaultInferenceEndpoints } from '@kbn/inference-common';
import { cloneDeep } from 'lodash';
import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/types';
import { i18n } from '@kbn/i18n';
import { isImpliedDefaultElserInferenceId } from '@kbn/product-doc-common/src/is_default_inference_endpoint';
import type { ProductDocInstallClient } from '../doc_install_status';
import {
downloadToDisk,
Expand Down Expand Up @@ -179,7 +180,7 @@ export class PackageInstaller {
inferenceId,
});

if (customInference && customInference?.inference_id !== this.elserInferenceId) {
if (customInference && !isImpliedDefaultElserInferenceId(customInference?.inference_id)) {
if (customInference?.task_type !== 'text_embedding') {
throw new Error(
`Inference [${inferenceId}]'s task type ${customInference?.task_type} is not supported. Please use a model with task type 'text_embedding'.`
Expand All @@ -191,7 +192,7 @@ export class PackageInstaller {
});
}

if (!customInference || customInference?.inference_id === this.elserInferenceId) {
if (!customInference || isImpliedDefaultElserInferenceId(customInference?.inference_id)) {
await ensureDefaultElserDeployed({
client: this.esClient,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ export async function registerDocumentationFunction({
resources,
pluginsStart: { llmTasks },
}: FunctionRegistrationParameters) {
const isProductDocAvailable = (await llmTasks.retrieveDocumentationAvailable()) ?? false;
const esClient = (await resources.context.core).elasticsearch.client;
const inferenceId =
(await getInferenceIdFromWriteIndex(esClient)) ?? defaultInferenceEndpoints.ELSER;
const isProductDocAvailable =
(await llmTasks.retrieveDocumentationAvailable({ inferenceId })) ?? false;

if (isProductDocAvailable) {
functions.registerInstruction(({ availableFunctionNames }) => {
Expand Down Expand Up @@ -66,11 +70,6 @@ export async function registerDocumentationFunction({
} as const,
},
async ({ arguments: { query, product }, connectorId, simulateFunctionCalling }) => {
const esClient = (await resources.context.core).elasticsearch.client;

const inferenceId =
(await getInferenceIdFromWriteIndex(esClient)) ?? defaultInferenceEndpoints.ELSER;

const response = await llmTasks!.retrieveDocumentation({
searchTerm: query,
products: product ? [product] : undefined,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
} from '@kbn/elastic-assistant-common';
import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common';
import { getRequestAbortedSignal } from '@kbn/data-plugin/server';
import { defaultInferenceEndpoints } from '@kbn/inference-common';
import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry';
import { ElasticAssistantPluginRouter } from '../../types';
import { buildResponse } from '../../lib/build_response';
Expand Down Expand Up @@ -86,7 +87,9 @@ export const chatCompleteRoute = (
telemetry = ctx.elasticAssistant.telemetry;
const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable({
inferenceId: defaultInferenceEndpoints.ELSER,
})) ?? false;

// Perform license and authenticated user checks
const checkResponse = await performChecks({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { getDefaultArguments } from '@kbn/langchain/server';
import { StructuredTool } from '@langchain/core/tools';
import { AgentFinish } from 'langchain/agents';
import { omit } from 'lodash/fp';
import { defaultInferenceEndpoints } from '@kbn/inference-common';
import { localToolPrompts, promptGroupId as toolsGroupId } from '../../lib/prompt/tool_prompts';
import { promptGroupId } from '../../lib/prompt/local_prompt_object';
import { getFormattedTime, getModelOrOss } from '../../lib/prompt/helpers';
Expand Down Expand Up @@ -173,7 +174,9 @@ export const postEvaluateRoute = (

const inference = ctx.elasticAssistant.inference;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable({
inferenceId: defaultInferenceEndpoints.ELSER,
})) ?? false;

const { featureFlags } = await context.core;
const inferenceChatModelDisabled = await featureFlags.getBooleanValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG,
} from '@kbn/elastic-assistant-common';
import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common';
import { defaultInferenceEndpoints } from '@kbn/inference-common';
import { getPrompt } from '../lib/prompt';
import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry';
import { buildResponse } from '../lib/build_response';
Expand Down Expand Up @@ -124,7 +125,9 @@ export const postActionsConnectorExecuteRoute = (
const inference = ctx.elasticAssistant.inference;
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable({
inferenceId: defaultInferenceEndpoints.ELSER,
})) ?? false;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
const connector = connectors.length > 0 ? connectors[0] : undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import { ChatCompletionMessageParam } from 'openai/resources';
import { last } from 'lodash';
import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import { TINY_ELSER_INFERENCE_ID } from '../../utils/model_and_inference';
import { LlmProxy, createLlmProxy } from '../../utils/create_llm_proxy';
import { chatComplete } from '../../utils/conversation';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
import { installProductDoc, uninstallProductDoc } from '../../utils/product_doc_base';

const DEFAULT_INFERENCE_ID = '.elser-2-elasticsearch';
import {
deployTinyElserAndSetupKb,
teardownTinyElserModelAndInferenceEndpoint,
} from '../../utils/model_and_inference';

export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
const log = getService('log');
Expand Down Expand Up @@ -86,14 +89,19 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let llmProxy: LlmProxy;
let connectorId: string;
let messageAddedEvents: MessageAddEvent[];
let firstRequestBody: ChatCompletionStreamParams;
let secondRequestBody: ChatCompletionStreamParams;
let toolCallRequestBody: ChatCompletionStreamParams;
let userPromptRequestBody: ChatCompletionStreamParams;

before(async () => {
llmProxy = await createLlmProxy(log);
connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
port: llmProxy.getPort(),
});
await installProductDoc(supertest, DEFAULT_INFERENCE_ID);
await deployTinyElserAndSetupKb(getService);

await installProductDoc(supertest, TINY_ELSER_INFERENCE_ID);

void llmProxy.interceptQueryRewrite('This is a rewritten user prompt.');

void llmProxy.interceptWithFunctionRequest({
name: 'retrieve_elastic_doc',
Expand All @@ -113,50 +121,49 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
}));

await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
firstRequestBody = llmProxy.interceptedRequests[0].requestBody;
secondRequestBody = llmProxy.interceptedRequests[1].requestBody;
toolCallRequestBody = llmProxy.interceptedRequests[1].requestBody;
userPromptRequestBody = llmProxy.interceptedRequests[2].requestBody;
});

after(async () => {
await uninstallProductDoc(supertest, DEFAULT_INFERENCE_ID);
await uninstallProductDoc(supertest, TINY_ELSER_INFERENCE_ID);
llmProxy.close();
await observabilityAIAssistantAPIClient.deleteActionConnector({
actionId: connectorId,
});
await teardownTinyElserModelAndInferenceEndpoint(getService);
});

it('makes 2 requests to the LLM', () => {
expect(llmProxy.interceptedRequests.length).to.be(2);
it('makes 3 requests to the LLM', () => {
expect(llmProxy.interceptedRequests.length).to.be(3);
});

it('emits 5 messageAdded events', () => {
expect(messageAddedEvents.length).to.be(5);
});

describe('The first request', () => {
it('contains the retrieve_elastic_doc function', () => {
expect(firstRequestBody.tools?.map((t) => t.function.name)).to.contain(
it('enables the LLM to call `retrieve_elastic_doc`', () => {
expect(toolCallRequestBody.tool_choice).to.be('auto');
expect(toolCallRequestBody.tools?.map((t) => t.function.name)).to.contain(
'retrieve_elastic_doc'
);
});

it('leaves the LLM to choose the correct tool by leave tool_choice as auto and passes tools', () => {
expect(firstRequestBody.tool_choice).to.be('auto');
expect(firstRequestBody.tools?.length).to.not.be(0);
});
});

describe('The second request - Sending the user prompt', () => {
let lastMessage: ChatCompletionMessageParam;
let parsedContent: { documents: Array<{ title: string; content: string; url: string }> };

before(() => {
lastMessage = last(secondRequestBody.messages) as ChatCompletionMessageParam;
lastMessage = last(userPromptRequestBody.messages) as ChatCompletionMessageParam;
parsedContent = JSON.parse(lastMessage.content as string);
});

it('includes the retrieve_elastic_doc function call', () => {
expect(secondRequestBody.messages[4].role).to.be(MessageRole.Assistant);
expect(userPromptRequestBody.messages[4].role).to.be(MessageRole.Assistant);
// @ts-expect-error
expect(secondRequestBody.messages[4].tool_calls[0].function.name).to.be(
expect(userPromptRequestBody.messages[4].tool_calls[0].function.name).to.be(
'retrieve_elastic_doc'
);
});
Expand All @@ -166,9 +173,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
// @ts-expect-error
expect(lastMessage?.tool_call_id).to.equal(
// @ts-expect-error
secondRequestBody.messages[4].tool_calls[0].id
userPromptRequestBody.messages[4].tool_calls[0].id
);
});

it('sends the retrieved documents from Elastic docs to the LLM', () => {
expect(lastMessage.content).to.be.a('string');
});
Expand Down